This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new f93eff3d9062 [SPARK-53329][CONNECT] Improve exception handling when adding artifacts
f93eff3d9062 is described below
commit f93eff3d9062aa24703624703cae49d68aed9be5
Author: Hendrik Huebner <[email protected]>
AuthorDate: Mon Sep 1 11:32:31 2025 +0800
[SPARK-53329][CONNECT] Improve exception handling when adding artifacts
### What changes were proposed in this pull request?
When a user sends multiple artifacts with the `addArtifacts` API, we
process each artifact one at a time on the server-side.
If the server detects an attempt to modify an artifact (overwriting an existing
artifact at the same path with a different byte sequence), an exception is
thrown immediately and the artifact-addition process is terminated.
Instead, the operation should be idempotent and the server should add as many
artifacts as possible rather than returning early.
### Why are the changes needed?
As explained above, if the server encounters an error while adding artifacts,
it returns immediately. This is wasteful because the server discards all
remaining artifacts sent over the wire, regardless of their own status. The
improvement made here is to process all artifacts, catch any exceptions, and
rethrow them at the end.
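For context, a minimal sketch of this collect-and-rethrow pattern is shown below.
The names `AddAllExample`, `ArtifactExistsException`, and `addAll` are illustrative
stand-ins, not part of this patch; the real logic lives in `ArtifactUtils` and the
handlers changed here.
```scala
import scala.collection.mutable

object AddAllExample {
  // Hypothetical stand-in for the real ARTIFACT_ALREADY_EXISTS error condition.
  final case class ArtifactExistsException(path: String)
    extends RuntimeException(s"artifact already exists: $path")

  // Add every artifact, remember the failures, and rethrow them merged at the end.
  def addAll(paths: Seq[String])(add: String => Unit): Unit = {
    val failures = mutable.ListBuffer[Throwable]()
    paths.foreach { p =>
      try add(p) // keep processing the remaining artifacts even if this one fails
      catch { case e: ArtifactExistsException => failures += e }
    }
    if (failures.nonEmpty) {
      val main = failures.head // the first failure becomes the thrown exception
      failures.tail.foreach(main.addSuppressed) // later failures ride along as suppressed
      throw main
    }
  }
}
```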
### Does this PR introduce _any_ user-facing change?
This PR does not modify the existing API or the return codes. If the above
scenario is triggered, the only user facing change is that the server adds as
many artifacts as possible. Therefore it should be fully backwards compatible.
Additionally, if more than one artifact already exists, the exceptions after the
first are attached to the first one as suppressed exceptions. Currently, however,
these suppressed exceptions are not serialized into the gRPC error object sent
over the wire.
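Continuing the illustrative sketch above, a local caller could still inspect every
conflict via `Throwable#getSuppressed`, even though the suppressed entries are not
yet propagated through the gRPC error:
```scala
// Reuses the hypothetical AddAllExample sketch from above.
try {
  AddAllExample.addAll(Seq("a.jar", "b.jar", "c.jar")) { p =>
    if (p != "b.jar") throw AddAllExample.ArtifactExistsException(p)
  }
} catch {
  case e: Throwable =>
    println(s"first conflict: ${e.getMessage}")
    e.getSuppressed.foreach(s => println(s"also conflicted: ${s.getMessage}"))
}
```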
### How was this patch tested?
Unit tests and local testing.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #52073 from HendrikHuebner/improve-add-artifact-exceptions.
Lead-authored-by: Hendrik Huebner <[email protected]>
Co-authored-by: Hendrik Huebner <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../org/apache/spark/sql/util/ArtifactUtils.scala | 16 ++++
.../service/SparkConnectAddArtifactsHandler.scala | 35 ++++++--
.../connect/service/AddArtifactsHandlerSuite.scala | 93 +++++++++++++++++++++-
.../spark/sql/artifact/ArtifactManager.scala | 50 +++++++-----
.../spark/sql/artifact/ArtifactManagerSuite.scala | 73 ++++++++++++++++-
5 files changed, 238 insertions(+), 29 deletions(-)
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/util/ArtifactUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/util/ArtifactUtils.scala
index 8cd239b55cff..0fc14b4c8b46 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/util/ArtifactUtils.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/util/ArtifactUtils.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.util
import java.nio.file.{Path, Paths}
+import org.apache.spark.SparkRuntimeException
+
object ArtifactUtils {
private[sql] def concatenatePaths(basePath: Path, otherPath: Path): Path = {
@@ -40,4 +42,18 @@ object ArtifactUtils {
private[sql] def concatenatePaths(basePath: Path, otherPath: String): Path = {
concatenatePaths(basePath, Paths.get(otherPath))
}
+
+ /**
+ * Converts a sequence of exceptions into a single exception by adding all but the first
+ * exceptions as suppressed exceptions to the first one.
+ * @param exceptions
+ * @return
+ */
+ private[sql] def mergeExceptionsWithSuppressed(
+ exceptions: Seq[SparkRuntimeException]): SparkRuntimeException = {
+ require(exceptions.nonEmpty)
+ val mainException = exceptions.head
+ exceptions.drop(1).foreach(mainException.addSuppressed)
+ mainException
+ }
}
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala
index 3ba79402e99e..becd7d855133 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala
@@ -26,6 +26,7 @@ import scala.util.control.NonFatal
import com.google.common.io.CountingOutputStream
import io.grpc.stub.StreamObserver
+import org.apache.spark.SparkRuntimeException
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse}
import org.apache.spark.connect.proto.AddArtifactsResponse.ArtifactSummary
@@ -112,19 +113,32 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr
* @return
*/
protected def flushStagedArtifacts(): Seq[ArtifactSummary] = {
+ val failedArtifactExceptions = mutable.ListBuffer[SparkRuntimeException]()
+
// Non-lazy transformation when using Buffer.
- stagedArtifacts.map { artifact =>
- // We do not store artifacts that fail the CRC. The failure is reported in the artifact
- // summary and it is up to the client to decide whether to retry sending the artifact.
- if (artifact.getCrcStatus.contains(true)) {
- if (artifact.path.startsWith(ArtifactManager.forwardToFSPrefix + File.separator)) {
- holder.artifactManager.uploadArtifactToFs(artifact.path, artifact.stagedPath)
- } else {
- addStagedArtifactToArtifactManager(artifact)
+ val summaries = stagedArtifacts.map { artifact =>
+ try {
+ // We do not store artifacts that fail the CRC. The failure is reported in the artifact
+ // summary and it is up to the client to decide whether to retry sending the artifact.
+ if (artifact.getCrcStatus.contains(true)) {
+ if (artifact.path.startsWith(ArtifactManager.forwardToFSPrefix + File.separator)) {
+ holder.artifactManager.uploadArtifactToFs(artifact.path, artifact.stagedPath)
+ } else {
+ addStagedArtifactToArtifactManager(artifact)
+ }
}
+ } catch {
case e: SparkRuntimeException if e.getCondition == "ARTIFACT_ALREADY_EXISTS" =>
+ failedArtifactExceptions += e
}
artifact.summary()
}.toSeq
+
+ if (failedArtifactExceptions.nonEmpty) {
+ throw ArtifactUtils.mergeExceptionsWithSuppressed(failedArtifactExceptions.toSeq)
+ }
+
+ summaries
}
protected def cleanUpStagedArtifacts(): Unit = Utils.deleteRecursively(stagingDir.toFile)
@@ -216,6 +230,7 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr
private val fileOut = Files.newOutputStream(stagedPath)
private val countingOut = new CountingOutputStream(fileOut)
private val checksumOut = new CheckedOutputStream(countingOut, new CRC32)
+ private val overallChecksum = new CRC32()
private val builder = ArtifactSummary.newBuilder().setName(name)
private var artifactSummary: ArtifactSummary = _
@@ -227,6 +242,8 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr
def getCrcStatus: Option[Boolean] = Option(isCrcSuccess)
+ def getCrc: Long = overallChecksum.getValue
+
def write(dataChunk: proto.AddArtifactsRequest.ArtifactChunk): Unit = {
try dataChunk.getData.writeTo(checksumOut)
catch {
@@ -234,6 +251,8 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr
close()
throw e
}
+
+ overallChecksum.update(dataChunk.getData.toByteArray)
updateCrc(checksumOut.getChecksum.getValue == dataChunk.getCrc)
checksumOut.getChecksum.reset()
}
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/AddArtifactsHandlerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/AddArtifactsHandlerSuite.scala
index a158ca9fad8c..6cc5daadfddd 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/AddArtifactsHandlerSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/AddArtifactsHandlerSuite.scala
@@ -32,6 +32,7 @@ import io.grpc.StatusRuntimeException
import io.grpc.protobuf.StatusProto
import io.grpc.stub.StreamObserver
+import org.apache.spark.SparkRuntimeException
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse}
import org.apache.spark.sql.connect.ResourceHelper
@@ -43,6 +44,7 @@ class AddArtifactsHandlerSuite extends SharedSparkSession with ResourceHelper {
private val CHUNK_SIZE: Int = 32 * 1024
private val sessionId = UUID.randomUUID.toString()
+ private val sessionKey = SessionKey("c1", sessionId)
class DummyStreamObserver(p: Promise[AddArtifactsResponse])
extends StreamObserver[AddArtifactsResponse] {
@@ -51,17 +53,31 @@ class AddArtifactsHandlerSuite extends SharedSparkSession with ResourceHelper {
override def onCompleted(): Unit = {}
}
- class TestAddArtifactsHandler(responseObserver: StreamObserver[AddArtifactsResponse])
+ class TestAddArtifactsHandler(
+ responseObserver: StreamObserver[AddArtifactsResponse],
+ throwIfArtifactExists: Boolean = false)
extends SparkConnectAddArtifactsHandler(responseObserver) {
// Stop the staged artifacts from being automatically deleted
override protected def cleanUpStagedArtifacts(): Unit = {}
private val finalArtifacts = mutable.Buffer.empty[String]
+ private val artifactChecksums: mutable.Map[String, Long] = mutable.Map.empty
// Record the artifacts that are sent out for final processing.
override protected def addStagedArtifactToArtifactManager(artifact: StagedArtifact): Unit = {
+ // Throw if artifact already exists and has different checksum
+ // This mocks the behavior of ArtifactManager.addArtifact without comparing the entire file
+ if (throwIfArtifactExists
+ && finalArtifacts.contains(artifact.name)
+ && artifact.getCrc != artifactChecksums(artifact.name)) {
+ throw new SparkRuntimeException(
+ "ARTIFACT_ALREADY_EXISTS",
+ Map("normalizedRemoteRelativePath" -> artifact.name))
+ }
+
finalArtifacts.append(artifact.name)
+ artifactChecksums += (artifact.name -> artifact.getCrc)
}
def getFinalArtifacts: Seq[String] = finalArtifacts.toSeq
@@ -418,4 +434,79 @@ class AddArtifactsHandlerSuite extends SharedSparkSession with ResourceHelper {
}
}
+ def addSingleChunkArtifact(
+ handler: SparkConnectAddArtifactsHandler,
+ sessionKey: SessionKey,
+ name: String,
+ artifactPath: Path): Unit = {
+ val dataChunks = getDataChunks(artifactPath)
+ assert(dataChunks.size == 1)
+ val bytes = dataChunks.head
+ val context = proto.UserContext
+ .newBuilder()
+ .setUserId(sessionKey.userId)
+ .build()
+ val fileNameNoExtension = artifactPath.getFileName.toString.split('.').head
+ val singleChunkArtifact = proto.AddArtifactsRequest.SingleChunkArtifact
+ .newBuilder()
+ .setName(name)
+ .setData(
+ proto.AddArtifactsRequest.ArtifactChunk
+ .newBuilder()
+ .setData(bytes)
+ .setCrc(getCrcValues(crcPath.resolve(fileNameNoExtension + ".txt")).head)
+ .build())
+ .build()
+
+ val singleChunkArtifactRequest = AddArtifactsRequest
+ .newBuilder()
+ .setSessionId(sessionKey.sessionId)
+ .setUserContext(context)
+ .setBatch(
+ proto.AddArtifactsRequest.Batch.newBuilder().addArtifacts(singleChunkArtifact).build())
+ .build()
+
+ handler.onNext(singleChunkArtifactRequest)
+ }
+
+ test("All artifacts are added, even if some fail") {
+ val promise = Promise[AddArtifactsResponse]()
+ val handler =
+ new TestAddArtifactsHandler(new DummyStreamObserver(promise), throwIfArtifactExists = true)
+ try {
+ val name1 = "jars/dummy1.jar"
+ val name2 = "jars/dummy2.jar"
+ val name3 = "jars/dummy3.jar"
+
+ val artifactPath1 = inputFilePath.resolve("smallClassFile.class")
+ val artifactPath2 = inputFilePath.resolve("smallJar.jar")
+
+ assume(artifactPath1.toFile.exists)
+ addSingleChunkArtifact(handler, sessionKey, name1, artifactPath1)
+ addSingleChunkArtifact(handler, sessionKey, name3, artifactPath1)
+
+ val e = intercept[StatusRuntimeException] {
+ addSingleChunkArtifact(handler, sessionKey, name1, artifactPath2)
+ addSingleChunkArtifact(handler, sessionKey, name2, artifactPath1)
+ addSingleChunkArtifact(handler, sessionKey, name3, artifactPath2)
+ handler.onCompleted()
+ }
+
+ // All three artifacts should be added, despite the exception
+ assert(handler.getFinalArtifacts.contains(name1))
+ assert(handler.getFinalArtifacts.contains(name2))
+ assert(handler.getFinalArtifacts.contains(name3))
+
+ assert(e.getStatus.getCode == Code.INTERNAL)
+ val statusProto = StatusProto.fromThrowable(e)
+ assert(statusProto.getDetailsCount == 1)
+ val details = statusProto.getDetails(0)
+ val info = details.unpack(classOf[ErrorInfo])
+
+ assert(e.getMessage.contains("ARTIFACT_ALREADY_EXISTS"))
+ assert(info.getMetadataMap().get("messageParameters").contains(name1))
+ } finally {
+ handler.forceCleanUp()
+ }
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala
index b0efd09d362a..de91e5e8a44b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala
@@ -25,6 +25,7 @@ import java.nio.file.{CopyOption, Files, Path, Paths, StandardCopyOption}
import java.util.concurrent.CopyOnWriteArrayList
import java.util.concurrent.atomic.AtomicBoolean
+import scala.collection.mutable.ListBuffer
import scala.jdk.CollectionConverters._
import scala.reflect.ClassTag
@@ -266,28 +267,39 @@ class ArtifactManager(session: SparkSession) extends AutoCloseable with Logging
* they are from a permanent location.
*/
private[sql] def addLocalArtifacts(artifacts: Seq[Artifact]): Unit = {
+ val failedArtifactExceptions = ListBuffer[SparkRuntimeException]()
+
artifacts.foreach { artifact =>
- artifact.storage match {
- case d: Artifact.LocalFile =>
- addArtifact(
- artifact.path,
- d.path,
- fragment = None,
- deleteStagedFile = false)
- case d: Artifact.InMemory =>
- val tempDir = Utils.createTempDir().toPath
- val tempFile = tempDir.resolve(artifact.path.getFileName)
- val outStream = Files.newOutputStream(tempFile)
- Utils.tryWithSafeFinallyAndFailureCallbacks {
- d.stream.transferTo(outStream)
- addArtifact(artifact.path, tempFile, fragment = None)
- }(finallyBlock = {
- outStream.close()
- })
- case _ =>
- throw SparkException.internalError(s"Unsupported artifact storage: ${artifact.storage}")
+ try {
+ artifact.storage match {
+ case d: Artifact.LocalFile =>
+ addArtifact(
+ artifact.path,
+ d.path,
+ fragment = None,
+ deleteStagedFile = false)
+ case d: Artifact.InMemory =>
+ val tempDir = Utils.createTempDir().toPath
+ val tempFile = tempDir.resolve(artifact.path.getFileName)
+ val outStream = Files.newOutputStream(tempFile)
+ Utils.tryWithSafeFinallyAndFailureCallbacks {
+ d.stream.transferTo(outStream)
+ addArtifact(artifact.path, tempFile, fragment = None)
+ }(finallyBlock = {
+ outStream.close()
+ })
+ case _ =>
+ throw SparkException.internalError(s"Unsupported artifact storage: ${artifact.storage}")
+ }
+ } catch {
+ case e: SparkRuntimeException if e.getCondition == "ARTIFACT_ALREADY_EXISTS" =>
+ failedArtifactExceptions += e
}
}
+
+ if (failedArtifactExceptions.nonEmpty) {
+ throw ArtifactUtils.mergeExceptionsWithSuppressed(failedArtifactExceptions.toSeq)
+ }
}
def classloader: ClassLoader = synchronized {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala
index 765d879bdd8d..f4a4ab012c2e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala
@@ -20,8 +20,9 @@ import java.io.File
import java.nio.charset.StandardCharsets
import java.nio.file.{Files, Path, Paths}
-import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.{SparkConf, SparkException, SparkRuntimeException}
import org.apache.spark.metrics.source.CodegenMetrics
+import org.apache.spark.sql.Artifact
import org.apache.spark.sql.classic.SparkSession
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.internal.SQLConf
@@ -346,6 +347,76 @@ class ArtifactManagerSuite extends SharedSparkSession {
}
}
+ test("Add multiple artifacts to local session and check if all are added
despite exception") {
+ val copyDir = Utils.createTempDir().toPath
+ Utils.copyDirectory(artifactPath.toFile, copyDir.toFile)
+
+ val artifact1Path = "my/custom/pkg/artifact1.jar"
+ val artifact2Path = "my/custom/pkg/artifact2.jar"
+ val targetPath = Paths.get(artifact1Path)
+ val targetPath2 = Paths.get(artifact2Path)
+
+ val classPath1 = copyDir.resolve("Hello.class")
+ val classPath2 = copyDir.resolve("udf_noA.jar")
+ assume(artifactPath.resolve("Hello.class").toFile.exists)
+ assume(artifactPath.resolve("smallClassFile.class").toFile.exists)
+
+ val artifact1 = Artifact.newArtifactFromExtension(
+ targetPath.getFileName.toString,
+ targetPath,
+ new Artifact.LocalFile(Paths.get(classPath1.toString)))
+
+ val alreadyExistingArtifact = Artifact.newArtifactFromExtension(
+ targetPath2.getFileName.toString,
+ targetPath,
+ new Artifact.LocalFile(Paths.get(classPath2.toString)))
+
+ val artifact2 = Artifact.newArtifactFromExtension(
+ targetPath2.getFileName.toString,
+ targetPath2,
+ new Artifact.LocalFile(Paths.get(classPath2.toString)))
+
+ spark.artifactManager.addLocalArtifacts(Seq(artifact1))
+
+ val ex = intercept[SparkRuntimeException] {
+ spark.artifactManager.addLocalArtifacts(
+ Seq(alreadyExistingArtifact, artifact2, alreadyExistingArtifact))
+ }
+
+ checkError(
+ exception = ex,
+ condition = "ARTIFACT_ALREADY_EXISTS",
+ parameters = Map("normalizedRemoteRelativePath" -> s"jars/${targetPath.toString}"))
+
+ assert(ex.getSuppressed.length == 1)
+ assert(ex.getSuppressed.head.isInstanceOf[SparkRuntimeException])
+ val suppressed = ex.getSuppressed.head.asInstanceOf[SparkRuntimeException]
+
+ checkError(
+ exception = suppressed,
+ condition = "ARTIFACT_ALREADY_EXISTS",
+ parameters = Map("normalizedRemoteRelativePath" -> s"jars/${targetPath.toString}"))
+
+ // Artifact1 should have been added
+ val expectedFile1 = ArtifactManager.artifactRootDirectory
+ .resolve(s"$sessionUUID/jars/$artifact1Path")
+ .toFile
+ assert(expectedFile1.exists())
+
+ // Artifact2 should have been added despite exception
+ val expectedFile2 = ArtifactManager.artifactRootDirectory
+ .resolve(s"$sessionUUID/jars/$artifact2Path")
+ .toFile
+ assert(expectedFile2.exists())
+
+ // Cleanup
+ artifactManager.cleanUpResourcesForTesting()
+ val sessionDir = ArtifactManager.artifactRootDirectory.resolve(sessionUUID).toFile
+
+ assert(!expectedFile1.exists())
+ assert(!sessionDir.exists())
+ }
+
test("Added artifact can be loaded by the current SparkSession") {
val path = artifactPath.resolve("IntSumUdf.class")
assume(path.toFile.exists)