This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a2bd7191dc5b [SPARK-55338][CONNECT] Centralize Spark Connect request 
decompression logic in gRPC interceptor
a2bd7191dc5b is described below

commit a2bd7191dc5ba6672789de42339aabef8cf6b66c
Author: Xi Lyu <[email protected]>
AuthorDate: Wed Mar 4 15:11:06 2026 -0400

    [SPARK-55338][CONNECT] Centralize Spark Connect request decompression logic 
in gRPC interceptor
    
    ### What changes were proposed in this pull request?
    
    Previously, the decompression logic introduced by
    https://github.com/apache/spark/pull/52894 was scattered across multiple handlers
    (SparkConnectAnalyzeHandler, SparkConnectExecutePlanHandler, and potentially more
    handlers in the future), with each handler responsible for decompressing its own
    compressed requests. This approach made it difficult to keep decompression behavior
    consistent across different RPC types.
    
    This PR introduces RequestDecompressionInterceptor, a gRPC interceptor that
    decompresses compressed requests (currently ExecutePlan and AnalyzePlan requests)
    before they reach downstream handlers.
    
    To improve the debuggability of decompression errors, this PR:
    * Enriches exceptions before returning them to the client, so that the correct error
      class and error message are propagated to the client, as with other errors thrown
      in handlers.
    * Adds more logging. On error, the driver logs contain detailed information for
      debugging.
    
    ### Why are the changes needed?
    
    This refactoring ensures single responsibility and consistency: the decompression
    logic now lives in one place, which makes it easier to maintain.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    New tests for the interceptor and existing tests for plan compression.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #54113 from xi-db/move-decompression-into-a-single-interceptor.
    
    Authored-by: Xi Lyu <[email protected]>
    Signed-off-by: Herman van Hövell <[email protected]>
---
 .../service/RequestDecompressionInterceptor.scala  | 331 ++++++++++++++++++++
 .../service/SparkConnectAnalyzeHandler.scala       |  25 +-
 .../service/SparkConnectExecutePlanHandler.scala   |  20 +-
 .../service/SparkConnectInterceptorRegistry.scala  |   4 +-
 .../spark/sql/connect/utils/ErrorUtils.scala       | 186 +++++++----
 .../RequestDecompressionInterceptorSuite.scala     | 346 +++++++++++++++++++++
 6 files changed, 845 insertions(+), 67 deletions(-)

diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/RequestDecompressionInterceptor.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/RequestDecompressionInterceptor.scala
new file mode 100644
index 000000000000..d93dc2069cf6
--- /dev/null
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/RequestDecompressionInterceptor.scala
@@ -0,0 +1,331 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.service
+
+import java.util.concurrent.atomic.AtomicReference
+
+import scala.util.control.NonFatal
+
+import io.grpc.{Context, Contexts, Metadata, ServerCall, ServerCallHandler, 
ServerInterceptor}
+import io.grpc.ForwardingServerCallListener.SimpleForwardingServerCallListener
+
+import org.apache.spark.connect.proto
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.LogKeys.{BYTE_SIZE, SESSION_ID, USER_ID}
+import org.apache.spark.sql.connect.utils.{ErrorUtils, PlanCompressionUtils}
+
+/**
+ * Interceptor that decompresses compressed requests before they reach downstream handlers.
+ *
+ * This interceptor currently handles:
+ *   - ExecutePlanRequest with compressed plans
+ *   - AnalyzePlanRequest with compressed plans (Schema, Explain, TreeString, IsLocal,
+ *     IsStreaming, InputFiles, SemanticHash, and SameSemantics analysis types)
+ *
+ * Compressed plan size metrics are tracked in gRPC Context for use by handlers.
+ *
+ * If decompression fails, the call is closed with an error status through
+ * `ErrorUtils.handleError` and the request is never forwarded downstream. Later messages and
+ * the half-close for that call are dropped as well, so the downstream handler is never invoked
+ * for a call this interceptor has already closed.
+ */
+class RequestDecompressionInterceptor extends ServerInterceptor with Logging {
+
+  override def interceptCall[ReqT, RespT](
+      call: ServerCall[ReqT, RespT],
+      headers: Metadata,
+      next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] = {
+
+    // Create an AtomicReference to hold compressed sizes
+    val compressedSizesRef = new AtomicReference[Seq[Option[Long]]](Seq.empty)
+
+    // Create context with the AtomicReference and start the call within that context
+    val ctx = Context
+      .current()
+      .withValue(RequestDecompressionContext.COMPRESSED_SIZES_KEY, compressedSizesRef)
+    val listener = Contexts.interceptCall(ctx, call, headers, next)
+
+    new SimpleForwardingServerCallListener[ReqT](listener) {
+
+      // Set once this interceptor has closed the call because decompression failed. From then
+      // on the downstream listener must not be driven any further: it never received a
+      // request, and gRPC's unary-request server listeners react to a half-close without a
+      // request by closing the call themselves, which is rejected for an already-closed call.
+      // Listener callbacks never overlap but may run on different threads, hence @volatile.
+      @volatile private var closedOnDecompressionError = false
+
+      override def onMessage(message: ReqT): Unit = {
+        if (!closedOnDecompressionError) {
+          message match {
+            case req: proto.ExecutePlanRequest =>
+              handleRequestWithDecompression(
+                compressedSizesRef,
+                req.getUserContext.getUserId,
+                req.getSessionId,
+                () => decompressExecutePlanRequest(req))
+
+            case req: proto.AnalyzePlanRequest =>
+              handleRequestWithDecompression(
+                compressedSizesRef,
+                req.getUserContext.getUserId,
+                req.getSessionId,
+                () => decompressAnalyzePlanRequest(req))
+
+            case other =>
+              // Forward all other message types as-is (no decompression or error handling
+              // needed)
+              super.onMessage(other)
+          }
+        }
+      }
+
+      override def onHalfClose(): Unit = {
+        // Only let the downstream handler run if it actually received the request.
+        if (!closedOnDecompressionError) {
+          super.onHalfClose()
+        }
+      }
+
+      /**
+       * Runs `decompressRequest`; on success records the compressed sizes in the call's
+       * context holder and forwards the decompressed request downstream. On a non-fatal
+       * failure, logs it, closes the call with an error status and marks the call as closed so
+       * no further callbacks are forwarded.
+       */
+      private def handleRequestWithDecompression[T](
+          compressedSizesRef: AtomicReference[Seq[Option[Long]]],
+          userId: String,
+          sessionId: String,
+          decompressRequest: () => (T, Seq[Option[Long]])): Unit = {
+        val (decompressedReq, compressedSizes) =
+          try {
+            decompressRequest()
+          } catch {
+            case NonFatal(e) =>
+              // Handle decompression errors
+              logError(
+                log"Plan decompression failed: " +
+                  log"userId=${MDC(USER_ID, userId)}, " +
+                  log"sessionId=${MDC(SESSION_ID, sessionId)}",
+                e)
+              // Mark first: from here on this interceptor owns termination of the call.
+              closedOnDecompressionError = true
+              ErrorUtils.handleError("planDecompression", call, userId, sessionId)(e)
+              return
+          }
+
+        // Set compressed sizes in the AtomicReference
+        compressedSizesRef.set(compressedSizes)
+        super.onMessage(decompressedReq.asInstanceOf[ReqT])
+      }
+    }
+  }
+
+  /**
+   * Decompresses the plan of an ExecutePlanRequest. Returns the request unchanged with
+   * Seq.empty when it has no plan, unchanged with Seq(None) when the plan is not compressed,
+   * and a rebuilt request with Seq(Some(compressedSize)) otherwise.
+   */
+  private def decompressExecutePlanRequest(
+      request: proto.ExecutePlanRequest): (proto.ExecutePlanRequest, Seq[Option[Long]]) = {
+    if (!request.hasPlan) {
+      return (request, Seq.empty)
+    }
+    decompressPlan(
+      request.getPlan,
+      request.getUserContext.getUserId,
+      request.getSessionId) match {
+      case Some((plan, size)) =>
+        (request.toBuilder.setPlan(plan).build(), Seq(Some(size)))
+      case None =>
+        (request, Seq(None))
+    }
+  }
+
+  /**
+   * Decompresses every Plan field of an AnalyzePlanRequest. Returns the (possibly rebuilt)
+   * request and one Option[Long] per Plan field (Seq.empty for cases without Plan fields).
+   *
+   * @throws UnsupportedOperationException for analyze cases not listed explicitly below
+   */
+  private def decompressAnalyzePlanRequest(
+      request: proto.AnalyzePlanRequest): (proto.AnalyzePlanRequest, Seq[Option[Long]]) = {
+    val userId = request.getUserContext.getUserId
+    val sessionId = request.getSessionId
+
+    // Helper: decompress a plan and rebuild the request only if compressed
+    def decompress(
+        req: proto.AnalyzePlanRequest,
+        plan: proto.Plan,
+        rebuild: proto.Plan => proto.AnalyzePlanRequest)
+        : (proto.AnalyzePlanRequest, Option[Long]) = {
+      decompressPlan(plan, userId, sessionId) match {
+        case Some((decompressedPlan, size)) => (rebuild(decompressedPlan), Some(size))
+        case None => (req, None)
+      }
+    }
+
+    // NOTE: All AnalyzePlanRequest cases are explicitly listed here.
+    // The default case throws an exception to catch new cases at runtime and fail CI tests.
+    request.getAnalyzeCase match {
+      // Cases with Plan fields - decompress if compressed
+      case proto.AnalyzePlanRequest.AnalyzeCase.SCHEMA =>
+        val (req, size) = decompress(
+          request,
+          request.getSchema.getPlan,
+          p => request.toBuilder.setSchema(request.getSchema.toBuilder.setPlan(p)).build())
+        (req, Seq(size))
+
+      case proto.AnalyzePlanRequest.AnalyzeCase.EXPLAIN =>
+        val (req, size) = decompress(
+          request,
+          request.getExplain.getPlan,
+          p => request.toBuilder.setExplain(request.getExplain.toBuilder.setPlan(p)).build())
+        (req, Seq(size))
+
+      case proto.AnalyzePlanRequest.AnalyzeCase.TREE_STRING =>
+        val (req, size) = decompress(
+          request,
+          request.getTreeString.getPlan,
+          p =>
+            request.toBuilder
+              .setTreeString(request.getTreeString.toBuilder.setPlan(p))
+              .build())
+        (req, Seq(size))
+
+      case proto.AnalyzePlanRequest.AnalyzeCase.IS_LOCAL =>
+        val (req, size) = decompress(
+          request,
+          request.getIsLocal.getPlan,
+          p => request.toBuilder.setIsLocal(request.getIsLocal.toBuilder.setPlan(p)).build())
+        (req, Seq(size))
+
+      case proto.AnalyzePlanRequest.AnalyzeCase.IS_STREAMING =>
+        val (req, size) = decompress(
+          request,
+          request.getIsStreaming.getPlan,
+          p =>
+            request.toBuilder
+              .setIsStreaming(request.getIsStreaming.toBuilder.setPlan(p))
+              .build())
+        (req, Seq(size))
+
+      case proto.AnalyzePlanRequest.AnalyzeCase.INPUT_FILES =>
+        val (req, size) = decompress(
+          request,
+          request.getInputFiles.getPlan,
+          p =>
+            request.toBuilder
+              .setInputFiles(request.getInputFiles.toBuilder.setPlan(p))
+              .build())
+        (req, Seq(size))
+
+      case proto.AnalyzePlanRequest.AnalyzeCase.SEMANTIC_HASH =>
+        val (req, size) = decompress(
+          request,
+          request.getSemanticHash.getPlan,
+          p =>
+            request.toBuilder
+              .setSemanticHash(request.getSemanticHash.toBuilder.setPlan(p))
+              .build())
+        (req, Seq(size))
+
+      case proto.AnalyzePlanRequest.AnalyzeCase.SAME_SEMANTICS =>
+        // Special case: has two Plan fields (target_plan and other_plan)
+        val (reqWithTarget, targetSize) = decompress(
+          request,
+          request.getSameSemantics.getTargetPlan,
+          p =>
+            request.toBuilder
+              .setSameSemantics(request.getSameSemantics.toBuilder.setTargetPlan(p))
+              .build())
+        val (finalReq, otherSize) = decompress(
+          reqWithTarget,
+          reqWithTarget.getSameSemantics.getOtherPlan,
+          p =>
+            reqWithTarget.toBuilder
+              .setSameSemantics(reqWithTarget.getSameSemantics.toBuilder.setOtherPlan(p))
+              .build())
+        (finalReq, Seq(targetSize, otherSize))
+
+      // Cases with Relation fields - currently not compressed
+      case proto.AnalyzePlanRequest.AnalyzeCase.PERSIST |
+          proto.AnalyzePlanRequest.AnalyzeCase.UNPERSIST |
+          proto.AnalyzePlanRequest.AnalyzeCase.GET_STORAGE_LEVEL =>
+        (request, Seq.empty)
+
+      // Cases with no Plan or Relation fields - safe to pass through
+      case proto.AnalyzePlanRequest.AnalyzeCase.SPARK_VERSION |
+          proto.AnalyzePlanRequest.AnalyzeCase.DDL_PARSE |
+          proto.AnalyzePlanRequest.AnalyzeCase.JSON_TO_DDL =>
+        (request, Seq.empty)
+
+      // No analysis case set - safe to pass through, will be handled in handler
+      case proto.AnalyzePlanRequest.AnalyzeCase.ANALYZE_NOT_SET =>
+        (request, Seq.empty)
+
+      case _ =>
+        // Unhandled case - fail to catch new cases during testing
+        throw new UnsupportedOperationException(
+          s"Unhandled AnalyzePlanRequest case: ${request.getAnalyzeCase}. " +
+            s"RequestDecompressionInterceptor must be updated to handle this case explicitly. " +
+            s"If the case contains Plan fields, add decompression logic. " +
+            s"Otherwise, add it to the safe passthrough cases.")
+    }
+  }
+
+  /**
+   * Decompresses a plan if it contains a compressed operation. Returns Some((decompressedPlan,
+   * compressedSize)) if compressed, None otherwise.
+   */
+  private def decompressPlan(
+      plan: proto.Plan,
+      userId: String,
+      sessionId: String): Option[(proto.Plan, Long)] = {
+    if (plan.getOpTypeCase != proto.Plan.OpTypeCase.COMPRESSED_OPERATION) {
+      return None
+    }
+    val compressedSize = plan.getCompressedOperation.getData.size().toLong
+    logInfo(
+      log"Received compressed plan " +
+        log"(size=${MDC(BYTE_SIZE, compressedSize)} bytes): " +
+        log"userId=${MDC(USER_ID, userId)}, sessionId=${MDC(SESSION_ID, sessionId)}")
+
+    val decompressedPlan = PlanCompressionUtils.decompressPlan(plan)
+    logInfo(
+      log"Plan decompression completed " +
+        log"(compressed=${MDC(BYTE_SIZE, compressedSize)} bytes -> " +
+        log"decompressed=${MDC(BYTE_SIZE, decompressedPlan.getSerializedSize.toLong)} bytes, " +
+        log"userId=${MDC(USER_ID, userId)}, sessionId=${MDC(SESSION_ID, sessionId)}")
+
+    Some((decompressedPlan, compressedSize))
+  }
+}
+
+/**
+ * Context holder for passing decompression metrics from interceptor to handlers. Uses gRPC
+ * Context to properly propagate values.
+ */
+object RequestDecompressionContext {
+
+  /**
+   * Context key for storing compressed sizes. RequestDecompressionInterceptor attaches a fresh
+   * AtomicReference (initially Seq.empty) to every call and fills it in after successfully
+   * decompressing an ExecutePlan or AnalyzePlan request; handlers read it for metrics.
+   *
+   * The sequence contains one Option[Long] entry per Plan field in the request:
+   *   - For ExecutePlan and single-plan Analyze requests: Seq(Some(size)) if the plan was
+   *     compressed, or Seq(None) if it was not
+   *   - For AnalyzePlanRequest SameSemantics with two plans: Seq(target_size, other_size) where
+   *     each is Some(size) if that plan was compressed, None if not
+   *   - For requests without Plan fields (an ExecutePlanRequest with no plan, or Analyze cases
+   *     without Plan fields): Seq.empty
+   *
+   * So the sequence length matches the number of plans, with explicit None for uncompressed
+   * plans.
+   */
+  val COMPRESSED_SIZES_KEY: Context.Key[AtomicReference[Seq[Option[Long]]]] =
+    Context.key("compressed-sizes")
+
+  /**
+   * Get all compressed sizes from the current gRPC context. Returns an empty sequence when the
+   * key is not attached to the current context or no sizes were recorded.
+   */
+  def getCompressedSizes: Seq[Option[Long]] =
+    Option(COMPRESSED_SIZES_KEY.get())
+      .flatMap(ref => Option(ref.get()))
+      .getOrElse(Seq.empty)
+
+  /**
+   * Get the first compressed size from the current gRPC context. Returns None if no compressed
+   * size was set (request was not compressed).
+   *
+   * This is the primary size for ExecutePlan and single-plan Analyze requests, and the
+   * target_plan size for SameSemantics requests.
+   */
+  def getCompressedSize: Option[Long] = {
+    getCompressedSizes.headOption.flatten
+  }
+
+  /**
+   * Get the second compressed size from the current gRPC context. Returns None if no second
+   * compressed size was set.
+   *
+   * This is only set for AnalyzePlanRequest SameSemantics with a compressed other_plan.
+   */
+  def getOtherCompressedSize: Option[Long] =
+    // `lift` yields None when there is no second entry; `flatten` unwraps Some(Some(size)).
+    getCompressedSizes.lift(1).flatten
+}
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
index ec8d95271c76..72029cafaa63 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
@@ -23,13 +23,14 @@ import io.grpc.stub.StreamObserver
 
 import org.apache.spark.connect.proto
 import org.apache.spark.internal.Logging
+import org.apache.spark.internal.LogKeys.BYTE_SIZE
 import org.apache.spark.sql.{AnalysisException, Row}
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.classic.{DataFrame, Dataset}
 import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, 
InvalidPlanInput, StorageLevelProtoConverter}
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
-import org.apache.spark.sql.connect.utils.{PipelineAnalysisContextUtils, 
PlanCompressionUtils}
+import org.apache.spark.sql.connect.utils.PipelineAnalysisContextUtils
 import org.apache.spark.sql.execution.{CodegenMode, CommandExecutionMode, 
CostMode, ExtendedMode, FormattedMode, SimpleMode}
 import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.util.ArrayImplicits._
@@ -69,11 +70,23 @@ private[connect] class SparkConnectAnalyzeHandler(
       throw new 
AnalysisException("ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION", Map())
     }
 
-    def transformRelation(rel: proto.Relation) = 
planner.transformRelation(rel, cachePlan = true)
-    def transformRelationPlan(plan: proto.Plan) = {
-      transformRelation(PlanCompressionUtils.decompressPlan(plan).getRoot)
+    // Log compressed sizes from gRPC Context (set by 
RequestDecompressionInterceptor)
+    val compressedSize = RequestDecompressionContext.getCompressedSize
+    val otherCompressedSize = 
RequestDecompressionContext.getOtherCompressedSize
+    if (compressedSize.isDefined || otherCompressedSize.isDefined) {
+      logDebug(
+        log"AnalyzePlan request received with compressed plan: " +
+          log"compressedSize=${MDC(BYTE_SIZE, compressedSize.getOrElse(0L))} 
bytes" +
+          otherCompressedSize
+            .map { size =>
+              log", otherCompressedSize=${MDC(BYTE_SIZE, size)} bytes"
+            }
+            .getOrElse(log""))
     }
 
+    def transformRelation(rel: proto.Relation) = 
planner.transformRelation(rel, cachePlan = true)
+    def transformRelationPlan(plan: proto.Plan) = 
transformRelation(plan.getRoot)
+
     def getDataFrameWithoutExecuting(rel: LogicalPlan): DataFrame = {
       val qe = session.sessionState.executePlan(rel, CommandExecutionMode.SKIP)
       new Dataset[Row](qe, () => RowEncoder.encoderFor(qe.analyzed.schema))
@@ -227,6 +240,10 @@ private[connect] class SparkConnectAnalyzeHandler(
             .setDdlString(ddl)
             .build())
 
+      // NOTE: When adding a new AnalyzePlanRequest case here, also update
+      // RequestDecompressionInterceptor.decompressAnalyzePlanRequest() to 
handle
+      // this case. The interceptor has a default case that throws 
UnsupportedOperationException
+      // for unhandled cases, which will fail tests and block CI if you forget 
to update it.
       case other => throw InvalidPlanInput(s"Unknown Analyze Method $other!")
     }
 
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala
index 6780ca37e96a..24a3d0541f55 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala
@@ -22,7 +22,7 @@ import io.grpc.stub.StreamObserver
 import org.apache.spark.SparkSQLException
 import org.apache.spark.connect.proto
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.connect.utils.PlanCompressionUtils
+import org.apache.spark.internal.LogKeys.BYTE_SIZE
 
 class SparkConnectExecutePlanHandler(responseObserver: 
StreamObserver[proto.ExecutePlanResponse])
     extends Logging {
@@ -36,20 +36,20 @@ class SparkConnectExecutePlanHandler(responseObserver: 
StreamObserver[proto.Exec
       .getOrCreateIsolatedSession(v.getUserContext.getUserId, v.getSessionId, 
previousSessionId)
     val executeKey = ExecuteKey(v, sessionHolder)
 
-    val decompressedRequest =
-      
v.toBuilder.setPlan(PlanCompressionUtils.decompressPlan(v.getPlan)).build()
+    // Log compressed sizes from gRPC Context (set by 
RequestDecompressionInterceptor)
+    val compressedSize = RequestDecompressionContext.getCompressedSize
+    if (compressedSize.isDefined) {
+      logDebug(
+        log"ExecutePlan request received with compressed plan: " +
+          log"compressedSize=${MDC(BYTE_SIZE, compressedSize.get)} bytes")
+    }
 
     SparkConnectService.executionManager.getExecuteHolder(executeKey) match {
       case None =>
         // Create a new execute holder and attach to it.
         SparkConnectService.executionManager
-          .createExecuteHolderAndAttach(
-            executeKey,
-            decompressedRequest,
-            sessionHolder,
-            responseObserver)
-      case Some(executeHolder)
-          if executeHolder.request.getPlan.equals(decompressedRequest.getPlan) 
=>
+          .createExecuteHolderAndAttach(executeKey, v, sessionHolder, 
responseObserver)
+      case Some(executeHolder) if 
executeHolder.request.getPlan.equals(v.getPlan) =>
         // If the execute holder already exists with the same plan, reattach 
to it.
         SparkConnectService.executionManager
           .reattachExecuteHolder(executeHolder, responseObserver, None)
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectInterceptorRegistry.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectInterceptorRegistry.scala
index 90759c00ccfc..8979d8ba8e60 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectInterceptorRegistry.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectInterceptorRegistry.scala
@@ -39,7 +39,9 @@ object SparkConnectInterceptorRegistry {
     // Adding a new interceptor at compile time works like the example below 
with the dummy
     // interceptor:
     // interceptor[DummyInterceptor](classOf[DummyInterceptor])
-  )
+
+    // Request decompression interceptor handles compressed requests from 
clients.
+    
interceptor[RequestDecompressionInterceptor](classOf[RequestDecompressionInterceptor]))
 
   /**
    * Given a NettyServerBuilder instance, will chain all interceptors to it in 
reverse order.
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
index 40a1f647f7c9..619fefe2e7c6 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
@@ -27,7 +27,7 @@ import scala.util.control.NonFatal
 
 import com.google.protobuf.{Any => ProtoAny}
 import com.google.rpc.{Code => RPCCode, ErrorInfo, Status => RPCStatus}
-import io.grpc.Status
+import io.grpc.{Metadata, ServerCall, Status, StatusRuntimeException}
 import io.grpc.protobuf.StatusProto
 import io.grpc.stub.StreamObserver
 import org.json4s.JsonDSL._
@@ -281,6 +281,97 @@ private[connect] object ErrorUtils extends Logging {
       .exists(_.toString.contains("org.apache.spark.sql.execution.python"))
   }
 
+  /**
+   * Error-processing core shared by the StreamObserver and ServerCall variants of
+   * `handleError`. It looks up the session (if it still exists), converts the throwable into a
+   * gRPC StatusRuntimeException, logs it, reports it to the ExecuteEventsManager when one is
+   * given, and finally runs the optional callback.
+   *
+   * @param opType
+   *   The operation type (analysis, execution, planDecompression, etc.)
+   * @param userId
+   *   The user id
+   * @param sessionId
+   *   The session id
+   * @param st
+   *   The throwable to process
+   * @param events
+   *   Optional ExecuteEventsManager to report failures to (None for interceptors). When present
+   *   the error is logged at INFO as a user query error; otherwise at ERROR as an RPC error.
+   * @param isInterrupted
+   *   Whether the error was caused by interruption; reported to `events` as a cancellation
+   *   rather than a failure
+   * @param callback
+   *   Optional callback, executed last
+   * @return
+   *   Tuple of (original throwable, wrapped StatusRuntimeException)
+   */
+  private def processErrorCommon(
+      opType: String,
+      userId: String,
+      sessionId: String,
+      st: Throwable,
+      events: Option[ExecuteEventsManager] = None,
+      isInterrupted: Boolean = false,
+      callback: Option[() => Unit] = None): (Throwable, StatusRuntimeException) = {
+
+    // The SessionHolder can be missing (e.g. the session was already closed); without it the
+    // error details cannot be served through FetchErrorDetails.
+    val sessionKey = SessionKey(userId, sessionId)
+    val sessionHolderOpt = SparkConnectService.sessionManager.getIsolatedSessionIfPresent(sessionKey)
+    if (sessionHolderOpt.isEmpty) {
+      logWarning(
+        log"SessionHolder not found during error handling for " +
+          log"${MDC(OP_TYPE, opType)}. " +
+          log"UserId: ${MDC(USER_ID, userId)}, SessionId: ${MDC(SESSION_ID, sessionId)}. " +
+          log"Error details will not be available for FetchErrorDetails.")
+    }
+
+    // Translate the throwable into the status exception that is sent to the client.
+    val statusException: StatusRuntimeException = st match {
+      case se: SparkException if isPythonExecutionException(se) =>
+        // For Python execution failures, build the status from the cause rather than from
+        // the wrapping SparkException.
+        StatusProto.toStatusRuntimeException(buildStatusFromThrowable(se.getCause, sessionHolderOpt))
+      case e if e.isInstanceOf[SparkThrowable] || NonFatal(e) =>
+        StatusProto.toStatusRuntimeException(buildStatusFromThrowable(e, sessionHolderOpt))
+      case fatal =>
+        Status.UNKNOWN
+          .withCause(fatal)
+          .withDescription(Utils.abbreviate(fatal.getMessage, 2048))
+          .asRuntimeException()
+    }
+
+    events match {
+      case Some(_) =>
+        // Errors raised during execution are user query errors, so log them at INFO.
+        logInfo(
+          log"Spark Connect error during: ${MDC(OP_TYPE, opType)}. " +
+            log"UserId: ${MDC(USER_ID, userId)}. SessionId: ${MDC(SESSION_ID, sessionId)}.",
+          st)
+      case None =>
+        // Anything else is a server RPC error, so log it at ERROR.
+        logError(
+          log"Spark Connect RPC error during: ${MDC(OP_TYPE, opType)}. " +
+            log"UserId: ${MDC(USER_ID, userId)}. SessionId: ${MDC(SESSION_ID, sessionId)}.",
+          st)
+    }
+
+    // Execution errors are additionally reported to the ExecuteEventsManager.
+    events.foreach { manager =>
+      if (isInterrupted) manager.postCanceled() else manager.postFailed(statusException.getMessage)
+    }
+
+    callback.foreach(cb => cb())
+
+    (st, statusException)
+  }
+
   /**
    * Common exception handling function for RPC methods. Closes the stream 
after the error has
    * been sent.
@@ -311,58 +402,49 @@ private[connect] object ErrorUtils extends Logging {
       events: Option[ExecuteEventsManager] = None,
       isInterrupted: Boolean = false,
       callback: Option[() => Unit] = None): PartialFunction[Throwable, Unit] = 
{
-
-    // SessionHolder may not be present, e.g. if the session was already 
closed.
-    // When SessionHolder is not present error details will not be available 
for FetchErrorDetails.
-    val sessionHolderOpt =
-      SparkConnectService.sessionManager.getIsolatedSessionIfPresent(
-        SessionKey(userId, sessionId))
-
-    val partial: PartialFunction[Throwable, (Throwable, Throwable)] = {
-      case se: SparkException if isPythonExecutionException(se) =>
-        (
-          se,
-          StatusProto.toStatusRuntimeException(
-            buildStatusFromThrowable(se.getCause, sessionHolderOpt)))
-
-      case e: Throwable if e.isInstanceOf[SparkThrowable] || NonFatal.apply(e) 
=>
-        (e, StatusProto.toStatusRuntimeException(buildStatusFromThrowable(e, 
sessionHolderOpt)))
-
-      case e: Throwable =>
-        (
-          e,
-          Status.UNKNOWN
-            .withCause(e)
-            .withDescription(Utils.abbreviate(e.getMessage, 2048))
-            .asRuntimeException())
+    { case st: Throwable =>
+      val (_, wrapped) =
+        processErrorCommon(opType, userId, sessionId, st, events, 
isInterrupted, callback)
+      observer.onError(wrapped)
     }
-    partial
-      .andThen { case (original, wrapped) =>
-        if (events.isDefined) {
-          // Errors thrown inside execution are user query errors, return then 
as INFO.
-          logInfo(
-            log"Spark Connect error during: ${MDC(OP_TYPE, opType)}. " +
-              log"UserId: ${MDC(USER_ID, userId)}. SessionId: 
${MDC(SESSION_ID, sessionId)}.",
-            original)
-        } else {
-          // Other errors are server RPC errors, return them as ERROR.
-          logError(
-            log"Spark Connect RPC error during: ${MDC(OP_TYPE, opType)}. " +
-              log"UserId: ${MDC(USER_ID, userId)}. SessionId: 
${MDC(SESSION_ID, sessionId)}.",
-            original)
-        }
+  }
 
-        // If ExecuteEventsManager is present, this this is an execution error 
that needs to be
-        // posted to it.
-        events.foreach { executeEventsManager =>
-          if (isInterrupted) {
-            executeEventsManager.postCanceled()
-          } else {
-            executeEventsManager.postFailed(wrapped.getMessage)
-          }
-        }
-        callback.foreach(_.apply())
-        observer.onError(wrapped)
-      }
+  /**
+   * Builds an error handler for failures raised inside a gRPC interceptor, before any
+   * StreamObserver exists. The handler runs the shared error processing (session lookup,
+   * status conversion, logging) and then closes the ServerCall with the resulting status and
+   * trailers.
+   *
+   * Interceptors have no ExecuteEventsManager, so no events are reported from here.
+   *
+   * @param opType
+   *   String value indicating the operation type (planDecompression, etc.); the gRPC method
+   *   name is appended to it for logging
+   * @param call
+   *   The ServerCall to close with error status
+   * @param userId
+   *   The user id
+   * @param sessionId
+   *   The session id
+   * @tparam ReqT
+   *   Request type
+   * @tparam RespT
+   *   Response type
+   * @return
+   *   PartialFunction that handles any Throwable by closing `call`
+   */
+  def handleError[ReqT, RespT](
+      opType: String,
+      call: ServerCall[ReqT, RespT],
+      userId: String,
+      sessionId: String): PartialFunction[Throwable, Unit] = { { case st: Throwable =>
+    // Tag the operation with the RPC method name so the log line identifies the failing call.
+    val opTypeWithMethod = s"$opType [${call.getMethodDescriptor.getBareMethodName}]"
+    val (_, statusException) = processErrorCommon(opTypeWithMethod, userId, sessionId, st)
+
+    // The status exception may carry no trailers; close with empty metadata in that case.
+    val trailers = Option(statusException.getTrailers).getOrElse(new Metadata())
+    call.close(statusException.getStatus, trailers)
+  }
   }
 }
diff --git 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/RequestDecompressionInterceptorSuite.scala
 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/RequestDecompressionInterceptorSuite.scala
new file mode 100644
index 000000000000..2c38d943f5be
--- /dev/null
+++ 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/RequestDecompressionInterceptorSuite.scala
@@ -0,0 +1,346 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.service
+
+import java.util.UUID
+
+import scala.util.Random
+
+import com.github.luben.zstd.Zstd
+import com.google.protobuf.ByteString
+import io.grpc.{Metadata, ServerCall, ServerCallHandler, Status}
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.test.SharedSparkSession
+
+class RequestDecompressionInterceptorSuite extends SparkFunSuite with SharedSparkSession {
+  private val testUserId = "testUserId"
+  private val testSessionId = UUID.randomUUID().toString
+  private val testUserCtx = proto.UserContext
+    .newBuilder()
+    .setUserId(testUserId)
+    .setUserName("testUserName")
+    .build()
+
+  // Helper: wrap a SQL relation as a ZSTD-compressed Plan (OP_TYPE_RELATION)
+  private def createCompressedPlan(query: String): proto.Plan = {
+    val relation = proto.Relation
+      .newBuilder()
+      .setSql(proto.SQL.newBuilder().setQuery(query))
+      .setCommon(proto.RelationCommon.newBuilder().setPlanId(Random.nextLong()).build())
+      .build()
+
+    val compressedBytes = Zstd.compress(relation.toByteArray)
+    proto.Plan
+      .newBuilder()
+      .setCompressedOperation(
+        proto.Plan.CompressedOperation
+          .newBuilder()
+          .setData(ByteString.copyFrom(compressedBytes))
+          .setOpType(proto.Plan.CompressedOperation.OpType.OP_TYPE_RELATION)
+          .setCompressionCodec(proto.CompressionCodec.COMPRESSION_CODEC_ZSTD)
+          .build())
+      .build()
+  }
+
+  // Helper: wrap a plan in an ExecutePlanRequest for the test session and user
+  private def createExecutePlanRequest(plan: proto.Plan): proto.ExecutePlanRequest = {
+    proto.ExecutePlanRequest
+      .newBuilder()
+      .setPlan(plan)
+      .setSessionId(testSessionId)
+      .setUserContext(testUserCtx)
+      .build()
+  }
+
+  // Mock ServerCall: records the status/trailers passed to close(); other calls are no-ops
+  private class TestServerCall extends ServerCall[Any, Any] {
+    var closedStatus: Status = _
+    var closedTrailers: Metadata = _
+
+    // Marshaller stub: only needed to build the MethodDescriptor, does no real (de)serialization
+    private val dummyMarshaller = new io.grpc.MethodDescriptor.Marshaller[Any] {
+      override def stream(value: Any): java.io.InputStream = null
+      override def parse(stream: java.io.InputStream): Any = null
+    }
+
+    private val descriptor = io.grpc.MethodDescriptor
+      .newBuilder[Any, Any]()
+      .setType(io.grpc.MethodDescriptor.MethodType.UNARY)
+      .setFullMethodName("spark.connect.SparkConnectService/ExecutePlan")
+      .setRequestMarshaller(dummyMarshaller)
+      .setResponseMarshaller(dummyMarshaller)
+      .build()
+
+    override def getMethodDescriptor: io.grpc.MethodDescriptor[Any, Any] = descriptor
+    override def close(status: Status, trailers: Metadata): Unit = {
+      closedStatus = status
+      closedTrailers = trailers
+    }
+    override def request(numMessages: Int): Unit = {}
+    override def sendHeaders(headers: Metadata): Unit = {}
+    override def sendMessage(message: Any): Unit = {}
+    override def isReady: Boolean = true
+    override def isCancelled: Boolean = false
+    override def setMessageCompression(enabled: Boolean): Unit = {}
+    override def setCompression(compressor: String): Unit = {}
+    override def getAttributes: io.grpc.Attributes = null
+    override def getAuthority: String = "localhost"
+  }
+
+  // Mock ServerCallHandler: records the message that reaches the downstream listener
+  private class TestHandler extends ServerCallHandler[Any, Any] {
+    var receivedMessage: Any = _
+    var capturedCompressedSize: Option[Long] = None
+
+    override def startCall(
+        call: ServerCall[Any, Any],
+        headers: Metadata): ServerCall.Listener[Any] = {
+      new ServerCall.Listener[Any] {
+        override def onMessage(message: Any): Unit = {
+          receivedMessage = message
+          // Capture the context value while it's still attached
+          capturedCompressedSize = RequestDecompressionContext.getCompressedSize
+        }
+      }
+    }
+  }
+
+  test("decompresses compressed ExecutePlanRequest") {
+    val interceptor = new RequestDecompressionInterceptor()
+    val compressedPlan = createCompressedPlan(s"select ${"Apache Spark" * 10000} as value")
+    val request = createExecutePlanRequest(compressedPlan)
+
+    val call = new TestServerCall()
+    val handler = new TestHandler()
+    val headers = new Metadata()
+
+    val listener = interceptor.interceptCall(call, headers, handler)
+    listener.onMessage(request)
+
+    // Verify the handler received a decompressed request
+    assert(handler.receivedMessage != null)
+    val decompressedRequest = handler.receivedMessage.asInstanceOf[proto.ExecutePlanRequest]
+    assert(!decompressedRequest.getPlan.hasCompressedOperation)
+    assert(decompressedRequest.getPlan.hasRoot) // Plan should be decompressed
+
+    // Verify compressed plan size was captured in context
+    assert(handler.capturedCompressedSize.exists(_ > 0))
+  }
+
+  test("passes through non-compressed ExecutePlanRequest unchanged") {
+    val interceptor = new RequestDecompressionInterceptor()
+    val normalPlan = proto.Plan
+      .newBuilder()
+      .setRoot(
+        proto.Relation
+          .newBuilder()
+          .setSql(proto.SQL.newBuilder().setQuery("SELECT 1")))
+      .build()
+    val request = createExecutePlanRequest(normalPlan)
+
+    val call = new TestServerCall()
+    val handler = new TestHandler()
+    val headers = new Metadata()
+
+    val listener = interceptor.interceptCall(call, headers, handler)
+    listener.onMessage(request)
+
+    // Verify the request passed through unchanged
+    assert(handler.receivedMessage == request)
+
+    // Verify no compressed size was set for non-compressed plans
+    assert(handler.capturedCompressedSize.isEmpty)
+  }
+
+  test("passes through FetchErrorDetailsRequest messages unchanged") {
+    val interceptor = new RequestDecompressionInterceptor()
+    val fetchErrorRequest = proto.FetchErrorDetailsRequest
+      .newBuilder()
+      .setSessionId(testSessionId)
+      .setUserContext(testUserCtx)
+      .setErrorId("test-error-id")
+      .build()
+
+    val call = new TestServerCall()
+    val handler = new TestHandler()
+    val headers = new Metadata()
+
+    val listener = interceptor.interceptCall(call, headers, handler)
+    listener.onMessage(fetchErrorRequest)
+
+    // Verify the message passed through unchanged
+    assert(handler.receivedMessage == fetchErrorRequest)
+  }
+
+  test("decompresses compressed AnalyzePlanRequest - Schema") {
+    val interceptor = new RequestDecompressionInterceptor()
+    val compressedPlan = createCompressedPlan(s"select ${"Apache Spark" * 10000} as value")
+    val request = proto.AnalyzePlanRequest
+      .newBuilder()
+      .setSessionId(testSessionId)
+      .setUserContext(testUserCtx)
+      .setSchema(proto.AnalyzePlanRequest.Schema.newBuilder().setPlan(compressedPlan))
+      .build()
+
+    val call = new TestServerCall()
+    val handler = new TestHandler()
+    val headers = new Metadata()
+
+    val listener = interceptor.interceptCall(call, headers, handler)
+    listener.onMessage(request)
+
+    assert(handler.receivedMessage != null)
+    val decompressed = handler.receivedMessage.asInstanceOf[proto.AnalyzePlanRequest]
+    assert(!decompressed.getSchema.getPlan.hasCompressedOperation)
+    assert(decompressed.getSchema.getPlan.hasRoot)
+    assert(handler.capturedCompressedSize.isDefined)
+  }
+
+  test("decompresses SameSemantics with both plans compressed") {
+    val interceptor = new RequestDecompressionInterceptor()
+    val plan1 = createCompressedPlan(s"select ${"Apache Spark" * 10000} as value")
+    val plan2 = createCompressedPlan("SELECT 1")
+    val request = proto.AnalyzePlanRequest
+      .newBuilder()
+      .setSessionId(testSessionId)
+      .setUserContext(testUserCtx)
+      .setSameSemantics(
+        proto.AnalyzePlanRequest.SameSemantics
+          .newBuilder()
+          .setTargetPlan(plan1)
+          .setOtherPlan(plan2))
+      .build()
+
+    val call = new TestServerCall()
+    val handler = new TestHandler()
+    val headers = new Metadata()
+
+    val listener = interceptor.interceptCall(call, headers, handler)
+    listener.onMessage(request)
+
+    assert(handler.receivedMessage != null)
+    val decompressed = handler.receivedMessage.asInstanceOf[proto.AnalyzePlanRequest]
+    assert(!decompressed.getSameSemantics.getTargetPlan.hasCompressedOperation)
+    assert(!decompressed.getSameSemantics.getOtherPlan.hasCompressedOperation)
+    assert(handler.capturedCompressedSize.isDefined)
+  }
+
+  test("passes through non-Plan AnalyzePlanRequest - SparkVersion") {
+    val interceptor = new RequestDecompressionInterceptor()
+    val request = proto.AnalyzePlanRequest
+      .newBuilder()
+      .setSessionId(testSessionId)
+      .setUserContext(testUserCtx)
+      .setSparkVersion(proto.AnalyzePlanRequest.SparkVersion.newBuilder())
+      .build()
+
+    val call = new TestServerCall()
+    val handler = new TestHandler()
+    val headers = new Metadata()
+
+    val listener = interceptor.interceptCall(call, headers, handler)
+    listener.onMessage(request)
+
+    assert(handler.receivedMessage == request)
+    assert(handler.capturedCompressedSize.isEmpty)
+  }
+
+  test("handles AnalyzePlanRequest decompression errors") {
+    val interceptor = new RequestDecompressionInterceptor()
+    val invalidPlan = proto.Plan
+      .newBuilder()
+      .setCompressedOperation(
+        proto.Plan.CompressedOperation
+          .newBuilder()
+          .setData(ByteString.copyFrom(Array[Byte](1, 2, 3, 4, 5)))
+          .setOpType(proto.Plan.CompressedOperation.OpType.OP_TYPE_RELATION)
+          .setCompressionCodec(proto.CompressionCodec.COMPRESSION_CODEC_ZSTD))
+      .build()
+
+    val request = proto.AnalyzePlanRequest
+      .newBuilder()
+      .setSessionId(testSessionId)
+      .setUserContext(testUserCtx)
+      .setSchema(proto.AnalyzePlanRequest.Schema.newBuilder().setPlan(invalidPlan))
+      .build()
+
+    val call = new TestServerCall()
+    val handler = new TestHandler()
+    val headers = new Metadata()
+
+    val listener = interceptor.interceptCall(call, headers, handler)
+    listener.onMessage(request)
+
+    assert(call.closedStatus != null)
+    assert(call.closedStatus.getCode != Status.OK.getCode)
+    assert(handler.receivedMessage == null, "Failed request must not reach the handler")
+  }
+
+  test("handles decompression errors with enriched error information") {
+    val interceptor = new RequestDecompressionInterceptor()
+
+    // Create an invalid compressed plan (corrupted data)
+    val invalidCompressedPlan = proto.Plan
+      .newBuilder()
+      .setCompressedOperation(
+        proto.Plan.CompressedOperation
+          .newBuilder()
+          .setData(ByteString.copyFrom(Array[Byte](1, 2, 3, 4, 5))) // Invalid compressed data
+          .setOpType(proto.Plan.CompressedOperation.OpType.OP_TYPE_RELATION)
+          .setCompressionCodec(proto.CompressionCodec.COMPRESSION_CODEC_ZSTD)
+          .build())
+      .build()
+
+    val request = createExecutePlanRequest(invalidCompressedPlan)
+
+    val call = new TestServerCall()
+    val handler = new TestHandler()
+    val headers = new Metadata()
+
+    val listener = interceptor.interceptCall(call, headers, handler)
+    listener.onMessage(request)
+
+    // Verify that ServerCall.close() was called with an error status
+    assert(call.closedStatus != null, "Decompression error should close the call")
+    assert(call.closedStatus.getCode != Status.OK.getCode, "Status should indicate an error")
+    assert(call.closedTrailers != null, "Error trailers should be set")
+    assert(handler.receivedMessage == null, "Failed request must not reach the handler")
+  }
+
+  test("sets compressed plan size in context for compressed plans") {
+    val interceptor = new RequestDecompressionInterceptor()
+    val compressedPlan = createCompressedPlan(s"select ${"Apache Spark" * 10000} as value")
+    val request = createExecutePlanRequest(compressedPlan)
+
+    val call = new TestServerCall()
+    val handler = new TestHandler()
+    val headers = new Metadata()
+
+    val listener = interceptor.interceptCall(call, headers, handler)
+    listener.onMessage(request)
+
+    // Verify decompression succeeded
+    assert(handler.receivedMessage != null)
+
+    // Verify compressed plan size was set in context
+    val actualCompressedSize = compressedPlan.getCompressedOperation.getData.size().toLong
+    assert(handler.capturedCompressedSize.contains(actualCompressedSize))
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to