This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 4863be5632f [SPARK-45207][SQL][CONNECT] Implement Error Enrichment for Scala Client
4863be5632f is described below
commit 4863be5632f3165a5699a525235ea118c1e1f7eb
Author: Yihong He <[email protected]>
AuthorDate: Mon Sep 25 09:35:33 2023 +0900
[SPARK-45207][SQL][CONNECT] Implement Error Enrichment for Scala Client
### What changes were proposed in this pull request?
- Implemented the reconstruction of the complete exception (un-truncated error messages, cause exceptions, server-side stack trace) from the responses of the FetchErrorDetails RPC.
### Why are the changes needed?
- Cause exceptions play an important role in the current control flow, such as in StreamingQueryException. They are also valuable for debugging.
- Un-truncated error messages are useful for debugging.
- Providing server-side stack traces aids in effectively diagnosing server-related issues.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
- `build/sbt "connect-client-jvm/testOnly *ClientE2ETestSuite"`
- `build/sbt "connect-client-jvm/testOnly *ClientStreamingQuerySuite"`
### Was this patch authored or co-authored using generative AI tooling?
No
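For illustration, a minimal sketch of how a client could observe the enriched errors (not part of this patch; the endpoint `sc://localhost` and the session setup are assumptions, while the config keys and the failing query come from the tests below):

```scala
import org.apache.spark.SparkUpgradeException
import org.apache.spark.sql.SparkSession

// Hypothetical endpoint; any Spark Connect server works.
val spark = SparkSession.builder().remote("sc://localhost").getOrCreate()
// Configs exercised by the tests in this patch.
spark.conf.set("spark.sql.connect.enrichError.enabled", "true")
spark.conf.set("spark.sql.connect.serverStacktrace.enabled", "true")

try {
  // '02-29' resolves to a non-leap year, so date parsing fails on the server.
  spark.sql("""select from_json('{"d": "02-29"}', 'd date', map('dateFormat', 'MM-dd'))""").collect()
} catch {
  case e: SparkUpgradeException =>
    // With enrichment enabled, the cause chain is reconstructed client-side
    // via the FetchErrorDetails RPC (here a java.time.DateTimeException).
    println(e.getCause)
}
```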
Closes #42987 from heyihong/SPARK-45207.
Authored-by: Yihong He <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../org/apache/spark/sql/ClientE2ETestSuite.scala | 59 ++++++-
.../sql/streaming/ClientStreamingQuerySuite.scala | 41 ++++-
.../client/CustomSparkConnectBlockingStub.scala | 44 ++++-
.../connect/client/GrpcExceptionConverter.scala | 192 +++++++++++++++++----
4 files changed, 292 insertions(+), 44 deletions(-)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index 21892542eab..ec9b1698a4e 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql
import java.io.{ByteArrayOutputStream, PrintStream}
import java.nio.file.Files
+import java.time.DateTimeException
import java.util.Properties
import scala.collection.JavaConverters._
@@ -29,7 +30,7 @@ import org.apache.commons.lang3.{JavaVersion, SystemUtils}
import org.scalactic.TolerantNumerics
import org.scalatest.PrivateMethodTester
-import org.apache.spark.{SparkArithmeticException, SparkException}
+import org.apache.spark.{SparkArithmeticException, SparkException, SparkUpgradeException}
import org.apache.spark.SparkBuildInfo.{spark_version => SPARK_VERSION}
import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchDatabaseException, NoSuchTableException, TableAlreadyExistsException, TempTableAlreadyExistsException}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder
@@ -44,6 +45,62 @@ import org.apache.spark.sql.types._
class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateMethodTester {
+ for (enrichErrorEnabled <- Seq(false, true)) {
+ test(s"cause exception - ${enrichErrorEnabled}") {
+ withSQLConf("spark.sql.connect.enrichError.enabled" -> enrichErrorEnabled.toString) {
+ val ex = intercept[SparkUpgradeException] {
+ spark
+ .sql("""
+ |select from_json(
+ | '{"d": "02-29"}',
+ | 'd date',
+ | map('dateFormat', 'MM-dd'))
+ |""".stripMargin)
+ .collect()
+ }
+ if (enrichErrorEnabled) {
+ assert(ex.getCause.isInstanceOf[DateTimeException])
+ } else {
+ assert(ex.getCause == null)
+ }
+ }
+ }
+ }
+
+ test(s"throw SparkException with large cause exception") {
+ withSQLConf("spark.sql.connect.enrichError.enabled" -> "true") {
+ val session = spark
+ import session.implicits._
+
+ val throwException =
+ udf((_: String) => throw new SparkException("test" * 10000))
+
+ val ex = intercept[SparkException] {
+ Seq("1").toDS.withColumn("udf_val", throwException($"value")).collect()
+ }
+
+ assert(ex.getCause.isInstanceOf[SparkException])
+ assert(ex.getCause.getMessage.contains("test" * 10000))
+ }
+ }
+
+ for (isServerStackTraceEnabled <- Seq(false, true)) {
+ test(s"server-side stack trace is set in exceptions - ${isServerStackTraceEnabled}") {
+ withSQLConf(
+ "spark.sql.connect.serverStacktrace.enabled" -> isServerStackTraceEnabled.toString,
+ "spark.sql.pyspark.jvmStacktrace.enabled" -> "false") {
+ val ex = intercept[AnalysisException] {
+ spark.sql("select x").collect()
+ }
+ assert(
+ ex.getStackTrace
+ .find(_.getClassName.contains("org.apache.spark.sql.catalyst.analysis.CheckAnalysis"))
+ .isDefined
+ == isServerStackTraceEnabled)
+ }
+ }
+ }
+
test("throw SparkArithmeticException") {
withSQLConf("spark.sql.ansi.enabled" -> "true") {
intercept[SparkArithmeticException] {
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
index dc4d441ec30..5d281cfbfeb 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
@@ -27,11 +27,11 @@ import org.scalatest.concurrent.Eventually.eventually
import org.scalatest.concurrent.Futures.timeout
import org.scalatest.time.SpanSugar._
+import org.apache.spark.SparkException
import org.apache.spark.api.java.function.VoidFunction2
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, ForeachWriter, Row, SparkSession}
-import org.apache.spark.sql.functions.col
-import org.apache.spark.sql.functions.window
+import org.apache.spark.sql.functions.{col, udf, window}
import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryIdleEvent, QueryStartedEvent, QueryTerminatedEvent}
import org.apache.spark.sql.test.{QueryTest, SQLHelper}
import org.apache.spark.util.SparkFileUtils
@@ -175,6 +175,43 @@ class ClientStreamingQuerySuite extends QueryTest with SQLHelper with Logging {
}
}
+ test("throw exception in streaming") {
+ // Disable spark.sql.pyspark.jvmStacktrace.enabled to avoid hitting the
+ // netty header limit.
+ withSQLConf("spark.sql.pyspark.jvmStacktrace.enabled" -> "false") {
+ val session = spark
+ import session.implicits._
+
+ val checkForTwo = udf((value: Int) => {
+ if (value == 2) {
+ throw new RuntimeException("Number 2 encountered!")
+ }
+ value
+ })
+
+ val query = spark.readStream
+ .format("rate")
+ .option("rowsPerSecond", "1")
+ .load()
+ .select(checkForTwo($"value").as("checkedValue"))
+ .writeStream
+ .outputMode("append")
+ .format("console")
+ .start()
+
+ val exception = intercept[SparkException] {
+ query.awaitTermination()
+ }
+
+ assert(exception.getCause.isInstanceOf[SparkException])
+ assert(exception.getCause.getCause.isInstanceOf[SparkException])
+ assert(exception.getCause.getCause.getCause.isInstanceOf[SparkException])
+ assert(
+ exception.getCause.getCause.getCause.getMessage
+ .contains("java.lang.RuntimeException: Number 2 encountered!"))
+ }
+ }
+
test("foreach Row") {
val writer = new TestForeachWriter[Row]
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
index 80edcfa8be1..f02704b2a02 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
@@ -27,11 +27,21 @@ private[connect] class CustomSparkConnectBlockingStub(
retryPolicy: GrpcRetryHandler.RetryPolicy) {
private val stub = SparkConnectServiceGrpc.newBlockingStub(channel)
+
private val retryHandler = new GrpcRetryHandler(retryPolicy)
+ // GrpcExceptionConverter with a GRPC stub for fetching error details from server.
+ private val grpcExceptionConverter = new GrpcExceptionConverter(stub)
+
def executePlan(request: ExecutePlanRequest): CloseableIterator[ExecutePlanResponse] = {
- GrpcExceptionConverter.convert {
- GrpcExceptionConverter.convertIterator[ExecutePlanResponse](
+ grpcExceptionConverter.convert(
+ request.getSessionId,
+ request.getUserContext,
+ request.getClientType) {
+ grpcExceptionConverter.convertIterator[ExecutePlanResponse](
+ request.getSessionId,
+ request.getUserContext,
+ request.getClientType,
retryHandler.RetryIterator[ExecutePlanRequest, ExecutePlanResponse](
request,
r => CloseableIterator(stub.executePlan(r).asScala)))
@@ -40,15 +50,24 @@ private[connect] class CustomSparkConnectBlockingStub(
def executePlanReattachable(
request: ExecutePlanRequest): CloseableIterator[ExecutePlanResponse] = {
- GrpcExceptionConverter.convert {
- GrpcExceptionConverter.convertIterator[ExecutePlanResponse](
+ grpcExceptionConverter.convert(
+ request.getSessionId,
+ request.getUserContext,
+ request.getClientType) {
+ grpcExceptionConverter.convertIterator[ExecutePlanResponse](
+ request.getSessionId,
+ request.getUserContext,
+ request.getClientType,
// Don't use retryHandler - own retry handling is inside.
new ExecutePlanResponseReattachableIterator(request, channel, retryPolicy))
}
}
def analyzePlan(request: AnalyzePlanRequest): AnalyzePlanResponse = {
- GrpcExceptionConverter.convert {
+ grpcExceptionConverter.convert(
+ request.getSessionId,
+ request.getUserContext,
+ request.getClientType) {
retryHandler.retry {
stub.analyzePlan(request)
}
@@ -56,7 +75,10 @@ private[connect] class CustomSparkConnectBlockingStub(
}
def config(request: ConfigRequest): ConfigResponse = {
- GrpcExceptionConverter.convert {
+ grpcExceptionConverter.convert(
+ request.getSessionId,
+ request.getUserContext,
+ request.getClientType) {
retryHandler.retry {
stub.config(request)
}
@@ -64,7 +86,10 @@ private[connect] class CustomSparkConnectBlockingStub(
}
def interrupt(request: InterruptRequest): InterruptResponse = {
- GrpcExceptionConverter.convert {
+ grpcExceptionConverter.convert(
+ request.getSessionId,
+ request.getUserContext,
+ request.getClientType) {
retryHandler.retry {
stub.interrupt(request)
}
@@ -72,7 +97,10 @@ private[connect] class CustomSparkConnectBlockingStub(
}
def artifactStatus(request: ArtifactStatusesRequest): ArtifactStatusesResponse = {
- GrpcExceptionConverter.convert {
+ grpcExceptionConverter.convert(
+ request.getSessionId,
+ request.getUserContext,
+ request.getClientType) {
retryHandler.retry {
stub.artifactStatus(request)
}
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
index fe9f6dc2b4a..edbc434ef96 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
@@ -24,49 +24,145 @@ import scala.reflect.ClassTag
import com.google.rpc.ErrorInfo
import io.grpc.StatusRuntimeException
import io.grpc.protobuf.StatusProto
+import org.json4s.DefaultFormats
+import org.json4s.jackson.JsonMethods
import org.apache.spark.{SparkArithmeticException, SparkArrayIndexOutOfBoundsException, SparkDateTimeException, SparkException, SparkIllegalArgumentException, SparkNumberFormatException, SparkRuntimeException, SparkUnsupportedOperationException, SparkUpgradeException}
+import org.apache.spark.connect.proto.{FetchErrorDetailsRequest, FetchErrorDetailsResponse, UserContext}
+import org.apache.spark.connect.proto.SparkConnectServiceGrpc.SparkConnectServiceBlockingStub
+import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchDatabaseException, NoSuchTableException, TableAlreadyExistsException, TempTableAlreadyExistsException}
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.trees.Origin
-import org.apache.spark.util.JsonUtils
-private[client] object GrpcExceptionConverter extends JsonUtils {
- def convert[T](f: => T): T = {
+/**
+ * GrpcExceptionConverter handles the conversion of StatusRuntimeExceptions into Spark exceptions.
+ * It does so by utilizing the ErrorInfo defined in error_details.proto and making an additional
+ * FetchErrorDetails RPC call to retrieve the full error message and optionally the server-side
+ * stacktrace.
+ *
+ * If the FetchErrorDetails RPC call succeeds, the exceptions will be constructed based on the
+ * response. If the RPC call fails, the exception will be constructed based on the ErrorInfo. If
+ * the ErrorInfo is missing, the exception will be constructed based on the StatusRuntimeException
+ * itself.
+ */
+private[client] class GrpcExceptionConverter(grpcStub: SparkConnectServiceBlockingStub)
+ extends Logging {
+ import GrpcExceptionConverter._
+
+ def convert[T](sessionId: String, userContext: UserContext, clientType: String)(f: => T): T = {
try {
f
} catch {
case e: StatusRuntimeException =>
- throw toThrowable(e)
+ throw toThrowable(e, sessionId, userContext, clientType)
}
}
- def convertIterator[T](iter: CloseableIterator[T]): CloseableIterator[T] = {
+ def convertIterator[T](
+ sessionId: String,
+ userContext: UserContext,
+ clientType: String,
+ iter: CloseableIterator[T]): CloseableIterator[T] = {
new WrappedCloseableIterator[T] {
override def innerIterator: Iterator[T] = iter
override def hasNext: Boolean = {
- convert {
+ convert(sessionId, userContext, clientType) {
iter.hasNext
}
}
override def next(): T = {
- convert {
+ convert(sessionId, userContext, clientType) {
iter.next()
}
}
override def close(): Unit = {
- convert {
+ convert(sessionId, userContext, clientType) {
iter.close()
}
}
}
}
+ /**
+ * Fetches enriched errors with full exception message and optionally stacktrace by issuing an
+ * additional RPC call to fetch error details. The RPC call is best-effort at-most-once.
+ */
+ private def fetchEnrichedError(
+ info: ErrorInfo,
+ sessionId: String,
+ userContext: UserContext,
+ clientType: String): Option[Throwable] = {
+ val errorId = info.getMetadataOrDefault("errorId", null)
+ if (errorId == null) {
+ logWarning("Unable to fetch enriched error since errorId is missing")
+ return None
+ }
+
+ try {
+ val errorDetailsResponse = grpcStub.fetchErrorDetails(
+ FetchErrorDetailsRequest
+ .newBuilder()
+ .setSessionId(sessionId)
+ .setErrorId(errorId)
+ .setUserContext(userContext)
+ .setClientType(clientType)
+ .build())
+
+ if (!errorDetailsResponse.hasRootErrorIdx) {
+ logWarning("Unable to fetch enriched error since error is not found")
+ return None
+ }
+
+ Some(
+ errorsToThrowable(
+ errorDetailsResponse.getRootErrorIdx,
+ errorDetailsResponse.getErrorsList.asScala.toSeq))
+ } catch {
+ case e: StatusRuntimeException =>
+ logWarning("Unable to fetch enriched error", e)
+ None
+ }
+ }
+
+ private def toThrowable(
+ ex: StatusRuntimeException,
+ sessionId: String,
+ userContext: UserContext,
+ clientType: String): Throwable = {
+ val status = StatusProto.fromThrowable(ex)
+
+ // Extract the ErrorInfo from the StatusProto, if present.
+ val errorInfoOpt = status.getDetailsList.asScala
+ .find(_.is(classOf[ErrorInfo]))
+ .map(_.unpack(classOf[ErrorInfo]))
+
+ if (errorInfoOpt.isDefined) {
+ // If ErrorInfo is found, try to fetch enriched error details by an additional RPC.
+ val enrichedErrorOpt =
+ fetchEnrichedError(errorInfoOpt.get, sessionId, userContext, clientType)
+ if (enrichedErrorOpt.isDefined) {
+ return enrichedErrorOpt.get
+ }
+
+ // If fetching enriched error details fails, convert ErrorInfo to a Throwable.
+ // Unlike enriched errors above, the message from status may be truncated,
+ // and no cause exceptions or server-side stack traces will be reconstructed.
+ return errorInfoToThrowable(errorInfoOpt.get, status.getMessage)
+ }
+
+ // If no ErrorInfo is found, create a SparkException based on the StatusRuntimeException.
+ new SparkException(ex.toString, ex.getCause)
+ }
+}
+
+private object GrpcExceptionConverter {
+
private def errorConstructor[T <: Throwable: ClassTag](
throwableCtr: (String, Option[Throwable]) => T)
: (String, (String, Option[Throwable]) => Throwable) = {
@@ -93,33 +189,63 @@ private[client] object GrpcExceptionConverter extends JsonUtils {
new SparkArrayIndexOutOfBoundsException(message)),
errorConstructor[DateTimeException]((message, _) => new SparkDateTimeException(message)),
errorConstructor((message, cause) => new SparkRuntimeException(message, cause)),
- errorConstructor((message, cause) => new SparkUpgradeException(message, cause)))
-
- private def errorInfoToThrowable(info: ErrorInfo, message: String): Option[Throwable] = {
- val classes =
- mapper.readValue(info.getMetadataOrDefault("classes", "[]"), classOf[Array[String]])
+ errorConstructor((message, cause) => new SparkUpgradeException(message, cause)),
+ errorConstructor((message, cause) => new SparkException(message, cause.orNull)))
+
+ /**
+ * errorsToThrowable reconstructs the exception based on a list of protobuf messages
+ * FetchErrorDetailsResponse.Error with un-truncated error messages and server-side stacktrace
+ * (if set).
+ */
+ private def errorsToThrowable(
+ errorIdx: Int,
+ errors: Seq[FetchErrorDetailsResponse.Error]): Throwable = {
+
+ val error = errors(errorIdx)
+
+ val classHierarchy = error.getErrorTypeHierarchyList.asScala
+
+ val constructor =
+ classHierarchy
+ .flatMap(errorFactory.get)
+ .headOption
+ .getOrElse((message: String, cause: Option[Throwable]) =>
+ new SparkException(s"${classHierarchy.head}: ${message}", cause.orNull))
+
+ val causeOpt =
+ if (error.hasCauseIdx) Some(errorsToThrowable(error.getCauseIdx, errors)) else None
+
+ val exception = constructor(error.getMessage, causeOpt)
+
+ if (!error.getStackTraceList.isEmpty) {
+ exception.setStackTrace(error.getStackTraceList.asScala.toArray.map { stackTraceElement =>
+ new StackTraceElement(
+ stackTraceElement.getDeclaringClass,
+ stackTraceElement.getMethodName,
+ stackTraceElement.getFileName,
+ stackTraceElement.getLineNumber)
+ })
+ }
- classes
- .find(errorFactory.contains)
- .map { cls =>
- val constructor = errorFactory.get(cls).get
- constructor(message, None)
- }
+ exception
}
- private def toThrowable(ex: StatusRuntimeException): Throwable = {
- val status = StatusProto.fromThrowable(ex)
-
- val fallbackEx = new SparkException(ex.toString, ex.getCause)
-
- val errorInfoOpt = status.getDetailsList.asScala
- .find(_.is(classOf[ErrorInfo]))
-
- if (errorInfoOpt.isEmpty) {
- return fallbackEx
- }
-
- errorInfoToThrowable(errorInfoOpt.get.unpack(classOf[ErrorInfo]), status.getMessage)
- .getOrElse(fallbackEx)
+ /**
+ * errorInfoToThrowable reconstructs the exception based on the error classes hierarchy and the
+ * truncated error message.
+ */
+ private def errorInfoToThrowable(info: ErrorInfo, message: String): Throwable = {
+ implicit val formats = DefaultFormats
+ val classes =
+ JsonMethods.parse(info.getMetadataOrDefault("classes", "[]")).extract[Array[String]]
+
+ errorsToThrowable(
+ 0,
+ Seq(
+ FetchErrorDetailsResponse.Error
+ .newBuilder()
+ .setMessage(message)
+ .addAllErrorTypeHierarchy(classes.toIterable.asJava)
+ .build()))
}
}
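As a self-contained companion to errorsToThrowable above, the following sketch models the recursive cause-chain reconstruction (ErrorNode and rebuild are hypothetical stand-ins for FetchErrorDetailsResponse.Error and the real method, which additionally maps the error type hierarchy through errorFactory and restores server-side stack traces):

```scala
object ErrorChainSketch extends App {
  // Flat list of errors; causeIdx points at another entry in the same list,
  // mirroring FetchErrorDetailsResponse.Error.getCauseIdx.
  case class ErrorNode(message: String, causeIdx: Option[Int])

  def rebuild(idx: Int, errors: Seq[ErrorNode]): Throwable = {
    val node = errors(idx)
    // Recursively rebuild the cause first, then attach it to this exception.
    val cause = node.causeIdx.map(rebuild(_, errors)).orNull
    new RuntimeException(node.message, cause)
  }

  // The server supplies the root index (getRootErrorIdx) plus the flat list.
  val chain = rebuild(0, Seq(ErrorNode("outer error", Some(1)), ErrorNode("inner error", None)))
  assert(chain.getCause.getMessage == "inner error")
}
```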
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]