This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 9d031ba8c99 [SPARK-44078][CONNECT][CORE] Add support for
classloader/resource isolation
9d031ba8c99 is described below
commit 9d031ba8c995286e5f8892764e5108aa60f49238
Author: vicennial <[email protected]>
AuthorDate: Wed Jun 21 20:58:19 2023 -0400
[SPARK-44078][CONNECT][CORE] Add support for classloader/resource isolation
### What changes were proposed in this pull request?
This PR adds a `JobArtifactSet` that holds the jars/files/archives relevant to a
particular Spark job. Using this set, we can specify which resources are
visible/available to a job based on, for example, the `SparkSession` that the job
belongs to.
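For illustration only (not part of this patch; the session id, class URI and jar path below are made up), a minimal sketch of how a caller could scope a session-specific set around job submission:

```scala
import org.apache.spark.JobArtifactSet

// Hypothetical session-scoped artifact set; the id, URI and paths are illustrative.
val sessionArtifacts = new JobArtifactSet(
  uuid = Some("repl-session-1234"),
  replClassDirUri = Some("spark://driver:4040/classes"),
  jars = Map("file:/tmp/session-udfs.jar" -> System.currentTimeMillis()),
  files = Map.empty,
  archives = Map.empty)

// Jobs submitted inside this block are tagged with sessionArtifacts, because the
// DAGScheduler resolves the set via JobArtifactSet.getActiveOrDefault(sc).
sessionArtifacts.withActive {
  // e.g. run an action here; its tasks carry these jars/files/archives to the executors
}
```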
With resource specification in place, we can further extend this to support
classloader/resource isolation on the executors. The executors use the `uuid` from
the `JobArtifactSet` to either create, or obtain from a cache, the
[IsolatedSessionState](https://github.com/apache/spark/pull/41625/files#diff-d7a989c491f3cb77cca02c701496a9e2a3443f70af73b0d1ab0899239f3a789dR57),
which holds the "state" (i.e. classloaders, files, jars, archives, etc.) for that
particular `uuid`.
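As a rough sketch of that executor-side lookup (a simplified `SessionState` stands in for `IsolatedSessionState` here; the real logic is in the `Executor.scala` diff below):

```scala
import java.util.concurrent.TimeUnit
import com.google.common.cache.CacheBuilder

// Simplified stand-in for the PR's IsolatedSessionState.
final case class SessionState(uuid: String, replClassLoader: ClassLoader)

// Cache keyed by session uuid, mirroring the executor-side cache added in this PR.
val sessionCache = CacheBuilder.newBuilder()
  .maximumSize(100)
  .expireAfterAccess(5, TimeUnit.MINUTES)
  .build[String, SessionState]

def resolveState(uuid: Option[String], default: SessionState): SessionState =
  uuid match {
    // The first task of a session creates the state; later tasks reuse the cached entry.
    case Some(id) => sessionCache.get(id, () => SessionState(id, getClass.getClassLoader))
    case None     => default // no uuid: fall back to the shared default state
  }
```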
For now, the code defaults to copying resources over from the `SparkContext`
(the current/default behaviour) to avoid any behaviour changes. A follow-up PR will
use this mechanism in Spark Connect to isolate resources among Spark Connect
sessions.
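A minimal sketch of that default path (assumes local mode; not taken from the patch):

```scala
import java.io.File
import org.apache.spark.{JobArtifactSet, SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("artifact-demo"))
try {
  val jarPath = File.createTempFile("demo", ".jar").getAbsolutePath
  sc.addJar(jarPath)

  // No JobArtifactSet is active, so the default set is a copy of the context's resources.
  val artifacts = JobArtifactSet.getActiveOrDefault(sc)
  assert(artifacts.uuid.isEmpty)   // default session, no isolation
  assert(artifacts.jars.nonEmpty)  // contains the jar registered on the SparkContext
} finally {
  sc.stop()
}
```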
### Why are the changes needed?
A current limitation of Scala UDFs is that a Spark cluster can only support a
single REPL at a time, because the classloaders of different Spark sessions (and
therefore, Spark Connect sessions) are not isolated from each other. Without
isolation, REPL-generated class files and user-added JARs may conflict when
multiple users share the cluster.
We therefore need a mechanism for isolated sessions (i.e. isolated
resources/classloaders) so that each REPL user does not conflict with other users
on the same cluster.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing tests + new suite `JobArtifactSetSuite`.
Closes #41625 from vicennial/SPARK-44078.
Authored-by: vicennial <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
---
.../scala/org/apache/spark/JobArtifactSet.scala | 123 +++++++++++++++++++++
.../scala/org/apache/spark/executor/Executor.scala | 120 +++++++++++++-------
.../org/apache/spark/scheduler/ActiveJob.scala | 3 +
.../org/apache/spark/scheduler/DAGScheduler.scala | 37 ++++---
.../apache/spark/scheduler/DAGSchedulerEvent.scala | 2 +
.../org/apache/spark/scheduler/ResultTask.scala | 5 +-
.../apache/spark/scheduler/ShuffleMapTask.scala | 9 +-
.../scala/org/apache/spark/scheduler/Task.scala | 2 +
.../apache/spark/scheduler/TaskDescription.scala | 61 ++++++----
.../apache/spark/scheduler/TaskSetManager.scala | 9 +-
.../org/apache/spark/JobArtifactSetSuite.scala | 87 +++++++++++++++
.../CoarseGrainedExecutorBackendSuite.scala | 7 +-
.../org/apache/spark/executor/ExecutorSuite.scala | 21 ++--
.../CoarseGrainedSchedulerBackendSuite.scala | 12 +-
.../apache/spark/scheduler/DAGSchedulerSuite.scala | 9 +-
.../org/apache/spark/scheduler/FakeTask.scala | 8 +-
.../spark/scheduler/NotSerializableFakeTask.scala | 4 +-
.../apache/spark/scheduler/TaskContextSuite.scala | 26 +++--
.../spark/scheduler/TaskDescriptionSuite.scala | 18 +--
.../spark/scheduler/TaskSchedulerImplSuite.scala | 4 +-
.../spark/scheduler/TaskSetManagerSuite.scala | 12 +-
.../MesosFineGrainedSchedulerBackendSuite.scala | 10 +-
22 files changed, 436 insertions(+), 153 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/JobArtifactSet.scala
b/core/src/main/scala/org/apache/spark/JobArtifactSet.scala
new file mode 100644
index 00000000000..d87c25c0b7c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/JobArtifactSet.scala
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.io.Serializable
+
+/**
+ * Artifact set for a job.
+ * This class is used to store session (i.e `SparkSession`) specific
resources/artifacts.
+ *
+ * When Spark Connect is used, this job-set points towards session-specific
jars and class files.
+ * Note that Spark Connect is not a requirement for using this class.
+ *
+ * @param uuid An optional UUID for this session. If unset, a default session
will be used.
+ * @param replClassDirUri An optional custom URI to point towards class files.
+ * @param jars Jars belonging to this session.
+ * @param files Files belonging to this session.
+ * @param archives Archives belonging to this session.
+ */
+class JobArtifactSet(
+ val uuid: Option[String],
+ val replClassDirUri: Option[String],
+ val jars: Map[String, Long],
+ val files: Map[String, Long],
+ val archives: Map[String, Long]) extends Serializable {
+ def withActive[T](f: => T): T = JobArtifactSet.withActive(this)(f)
+
+ override def hashCode(): Int = {
+ Seq(uuid, replClassDirUri, jars.toSeq, files.toSeq,
archives.toSeq).hashCode()
+ }
+
+ override def equals(obj: Any): Boolean = {
+ obj match {
+ case that: JobArtifactSet =>
+ this.getClass == that.getClass && this.uuid == that.uuid &&
+ this.replClassDirUri == that.replClassDirUri && this.jars.toSeq ==
that.jars.toSeq &&
+ this.files.toSeq == that.files.toSeq && this.archives.toSeq ==
that.archives.toSeq
+ }
+ }
+
+}
+
+object JobArtifactSet {
+
+ private[this] val current = new ThreadLocal[Option[JobArtifactSet]] {
+ override def initialValue(): Option[JobArtifactSet] = None
+ }
+
+ /**
+ * When Spark Connect isn't used, we default back to the shared resources.
+ * @param sc The active [[SparkContext]]
+ * @return A [[JobArtifactSet]] containing a copy of the jars/files/archives
from the underlying
+ * [[SparkContext]] `sc`.
+ */
+ def apply(sc: SparkContext): JobArtifactSet = {
+ new JobArtifactSet(
+ uuid = None,
+ replClassDirUri = sc.conf.getOption("spark.repl.class.uri"),
+ jars = sc.addedJars.toMap,
+ files = sc.addedFiles.toMap,
+ archives = sc.addedArchives.toMap)
+ }
+
+ /**
+ * Empty artifact set for use in tests.
+ */
+ private[spark] def apply(): JobArtifactSet = {
+ new JobArtifactSet(
+ None,
+ None,
+ Map.empty,
+ Map.empty,
+ Map.empty)
+ }
+
+ /**
+ * Used for testing. Returns artifacts from [[SparkContext]] if one exists
or otherwise, an
+ * empty set.
+ */
+ private[spark] def defaultArtifactSet(): JobArtifactSet = {
+ SparkContext.getActive.map(sc =>
JobArtifactSet(sc)).getOrElse(JobArtifactSet())
+ }
+
+ /**
+ * Execute a block of code with the currently active [[JobArtifactSet]].
+ * @param active
+ * @param block
+ * @tparam T
+ */
+ def withActive[T](active: JobArtifactSet)(block: => T): T = {
+ val old = current.get()
+ current.set(Option(active))
+ try block finally {
+ current.set(old)
+ }
+ }
+
+ /**
+ * Optionally returns the active [[JobArtifactSet]].
+ */
+ def active: Option[JobArtifactSet] = current.get()
+
+ /**
+ * Returns the active [[JobArtifactSet]] or creates the default set using the
[[SparkContext]].
+ * @param sc
+ */
+ def getActiveOrDefault(sc: SparkContext): JobArtifactSet =
active.getOrElse(JobArtifactSet(sc))
+}
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala
b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index ed3e8626d6d..a38a8efcd76 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -31,10 +31,11 @@ import javax.ws.rs.core.UriBuilder
import scala.collection.JavaConverters._
import scala.collection.immutable
-import scala.collection.mutable.{ArrayBuffer, HashMap, Map, WrappedArray}
+import scala.collection.mutable.{ArrayBuffer, HashMap, WrappedArray}
import scala.concurrent.duration._
import scala.util.control.NonFatal
+import com.google.common.cache.CacheBuilder
import com.google.common.util.concurrent.ThreadFactoryBuilder
import org.slf4j.MDC
@@ -53,6 +54,14 @@ import org.apache.spark.shuffle.{FetchFailedException,
ShuffleBlockPusher}
import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
import org.apache.spark.util._
+private[spark] class IsolatedSessionState(
+ val sessionUUID: String,
+ val urlClassLoader: MutableURLClassLoader,
+ var replClassLoader: ClassLoader,
+ val currentFiles: HashMap[String, Long],
+ val currentJars: HashMap[String, Long],
+ val currentArchives: HashMap[String, Long])
+
/**
* Spark executor, backed by a threadpool to run tasks.
*
@@ -76,11 +85,6 @@ private[spark] class Executor(
val stopHookReference = ShutdownHookManager.addShutdownHook(
() => stop()
)
- // Application dependencies (added through SparkContext) that we've fetched
so far on this node.
- // Each map holds the master's timestamp for the version of that file or JAR
we got.
- private val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
- private val currentJars: HashMap[String, Long] = new HashMap[String, Long]()
- private val currentArchives: HashMap[String, Long] = new HashMap[String,
Long]()
private val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0))
@@ -160,16 +164,34 @@ private[spark] class Executor(
private val killOnFatalErrorDepth =
conf.get(EXECUTOR_KILL_ON_FATAL_ERROR_DEPTH)
- // Create our ClassLoader
- // do this after SparkEnv creation so can access the SecurityManager
- private val urlClassLoader = createClassLoader()
- private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
+ private val systemLoader = Utils.getContextOrSparkClassLoader
+
+ private def newSessionState(
+ sessionUUID: String,
+ classUri: Option[String]): IsolatedSessionState = {
+ val currentFiles = new HashMap[String, Long]
+ val currentJars = new HashMap[String, Long]
+ val currentArchives = new HashMap[String, Long]
+ val urlClassLoader = createClassLoader(currentJars)
+ val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader, classUri)
+ new IsolatedSessionState(
+ sessionUUID, urlClassLoader, replClassLoader, currentFiles, currentJars,
currentArchives)
+ }
+
+ // Classloader isolation
+ // The default isolation group
+ val defaultSessionState = newSessionState("default", None)
+
+ val isolatedSessionCache = CacheBuilder.newBuilder()
+ .maximumSize(100)
+ .expireAfterAccess(5, TimeUnit.MINUTES)
+ .build[String, IsolatedSessionState]
// Set the classloader for serializer
- env.serializer.setDefaultClassLoader(replClassLoader)
+ env.serializer.setDefaultClassLoader(defaultSessionState.replClassLoader)
// SPARK-21928. SerializerManager's internal instance of Kryo might get
used in netty threads
// for fetching remote cached RDD blocks, so need to make sure it uses the
right classloader too.
- env.serializerManager.setDefaultClassLoader(replClassLoader)
+
env.serializerManager.setDefaultClassLoader(defaultSessionState.replClassLoader)
// Max size of direct result. If task result is bigger than this, we use the
block manager
// to send the result back. This is guaranteed to be smaller than array
bytes limit (2GB)
@@ -273,17 +295,18 @@ private[spark] class Executor(
private val Seq(initialUserJars, initialUserFiles, initialUserArchives) =
Seq("jar", "file", "archive").map { key =>
conf.getOption(s"spark.app.initial.$key.urls").map { urls =>
- Map(urls.split(",").map(url => (url, appStartTime)): _*)
- }.getOrElse(Map.empty)
+ immutable.Map(urls.split(",").map(url => (url, appStartTime)): _*)
+ }.getOrElse(immutable.Map.empty)
}
- updateDependencies(initialUserFiles, initialUserJars, initialUserArchives)
+ updateDependencies(initialUserFiles, initialUserJars, initialUserArchives,
defaultSessionState)
// Plugins need to load using a class loader that includes the executor's
user classpath.
// Plugins also needs to be initialized after the heartbeater started
// to avoid blocking to send heartbeat (see SPARK-32175).
- private val plugins: Option[PluginContainer] =
Utils.withContextClassLoader(replClassLoader) {
- PluginContainer(env, resources.asJava)
- }
+ private val plugins: Option[PluginContainer] =
+ Utils.withContextClassLoader(defaultSessionState.replClassLoader) {
+ PluginContainer(env, resources.asJava)
+ }
metricsPoller.start()
@@ -381,9 +404,9 @@ private[spark] class Executor(
if (killMarkCleanupService != null) {
killMarkCleanupService.shutdown()
}
- if (replClassLoader != null && plugins != null) {
+ if (defaultSessionState != null && plugins != null) {
// Notify plugins that executor is shutting down so they can terminate
cleanly
- Utils.withContextClassLoader(replClassLoader) {
+ Utils.withContextClassLoader(defaultSessionState.replClassLoader) {
plugins.foreach(_.shutdown())
}
}
@@ -485,6 +508,16 @@ private[spark] class Executor(
}
override def run(): Unit = {
+
+ // Classloader isolation
+ val isolatedSessionUUID: Option[String] = taskDescription.artifacts.uuid
+ val isolatedSession = isolatedSessionUUID match {
+ case Some(uuid) => isolatedSessionCache.get(
+ uuid,
+ () => newSessionState(uuid,
taskDescription.artifacts.replClassDirUri))
+ case _ => defaultSessionState
+ }
+
setMDCForTask(taskName, mdcProperties)
threadId = Thread.currentThread.getId
Thread.currentThread.setName(threadName)
@@ -494,7 +527,7 @@ private[spark] class Executor(
val deserializeStartCpuTime = if
(threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime
} else 0L
- Thread.currentThread.setContextClassLoader(replClassLoader)
+
Thread.currentThread.setContextClassLoader(isolatedSession.replClassLoader)
val ser = env.closureSerializer.newInstance()
logInfo(s"Running $taskName")
execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
@@ -509,7 +542,10 @@ private[spark] class Executor(
Executor.taskDeserializationProps.set(taskDescription.properties)
updateDependencies(
- taskDescription.addedFiles, taskDescription.addedJars,
taskDescription.addedArchives)
+ taskDescription.artifacts.files,
+ taskDescription.artifacts.jars,
+ taskDescription.artifacts.archives,
+ isolatedSession)
task = ser.deserialize[Task[Any]](
taskDescription.serializedTask,
Thread.currentThread.getContextClassLoader)
task.localProperties = taskDescription.properties
@@ -961,15 +997,13 @@ private[spark] class Executor(
* Create a ClassLoader for use in tasks, adding any JARs specified by the
user or any classes
* created by the interpreter to the search path
*/
- private def createClassLoader(): MutableURLClassLoader = {
+ private def createClassLoader(currentJars: HashMap[String, Long]):
MutableURLClassLoader = {
// Bootstrap the list of jars with the user class path.
val now = System.currentTimeMillis()
userClassPath.foreach { url =>
currentJars(url.getPath().split("/").last) = now
}
- val currentLoader = Utils.getContextOrSparkClassLoader
-
// For each of the jars in the jarSet, add them to the class loader.
// We assume each of the files has already been fetched.
val urls = userClassPath.toArray ++ currentJars.keySet.map { uri =>
@@ -978,9 +1012,9 @@ private[spark] class Executor(
logInfo(s"Starting executor with user classpath (userClassPathFirst =
$userClassPathFirst): " +
urls.mkString("'", ",", "'"))
if (userClassPathFirst) {
- new ChildFirstURLClassLoader(urls, currentLoader)
+ new ChildFirstURLClassLoader(urls, systemLoader)
} else {
- new MutableURLClassLoader(urls, currentLoader)
+ new MutableURLClassLoader(urls, systemLoader)
}
}
@@ -988,8 +1022,10 @@ private[spark] class Executor(
* If the REPL is in use, add another ClassLoader that will read
* new classes defined by the REPL as the user types code
*/
- private def addReplClassLoaderIfNeeded(parent: ClassLoader): ClassLoader = {
- val classUri = conf.get("spark.repl.class.uri", null)
+ private def addReplClassLoaderIfNeeded(
+ parent: ClassLoader,
+ sessionClassUri: Option[String]): ClassLoader = {
+ val classUri = sessionClassUri.getOrElse(conf.get("spark.repl.class.uri",
null))
if (classUri != null) {
logInfo("Using REPL class URI: " + classUri)
new ExecutorClassLoader(conf, env, classUri, parent, userClassPathFirst)
@@ -1004,9 +1040,10 @@ private[spark] class Executor(
* Visible for testing.
*/
private[executor] def updateDependencies(
- newFiles: Map[String, Long],
- newJars: Map[String, Long],
- newArchives: Map[String, Long],
+ newFiles: immutable.Map[String, Long],
+ newJars: immutable.Map[String, Long],
+ newArchives: immutable.Map[String, Long],
+ state: IsolatedSessionState,
testStartLatch: Option[CountDownLatch] = None,
testEndLatch: Option[CountDownLatch] = None): Unit = {
lazy val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
@@ -1015,14 +1052,15 @@ private[spark] class Executor(
// For testing, so we can simulate a slow file download:
testStartLatch.foreach(_.countDown())
// Fetch missing dependencies
- for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L)
< timestamp) {
+ for ((name, timestamp) <- newFiles if state.currentFiles.getOrElse(name,
-1L) < timestamp) {
logInfo(s"Fetching $name with timestamp $timestamp")
// Fetch file with useCache mode, close cache for local mode.
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf,
hadoopConf, timestamp, useCache = !isLocal)
- currentFiles(name) = timestamp
+ state.currentFiles(name) = timestamp
}
- for ((name, timestamp) <- newArchives if currentArchives.getOrElse(name,
-1L) < timestamp) {
+ for ((name, timestamp) <- newArchives if
+ state.currentArchives.getOrElse(name, -1L) < timestamp) {
logInfo(s"Fetching $name with timestamp $timestamp")
val sourceURI = new URI(name)
val uriToDownload =
UriBuilder.fromUri(sourceURI).fragment(null).build()
@@ -1035,24 +1073,24 @@ private[spark] class Executor(
s"Unpacking an archive $name from ${source.getAbsolutePath} to
${dest.getAbsolutePath}")
Utils.deleteRecursively(dest)
Utils.unpack(source, dest)
- currentArchives(name) = timestamp
+ state.currentArchives(name) = timestamp
}
for ((name, timestamp) <- newJars) {
val localName = new URI(name).getPath.split("/").last
- val currentTimeStamp = currentJars.get(name)
- .orElse(currentJars.get(localName))
+ val currentTimeStamp = state.currentJars.get(name)
+ .orElse(state.currentJars.get(localName))
.getOrElse(-1L)
if (currentTimeStamp < timestamp) {
logInfo(s"Fetching $name with timestamp $timestamp")
// Fetch file with useCache mode, close cache for local mode.
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf,
hadoopConf, timestamp, useCache = !isLocal)
- currentJars(name) = timestamp
+ state.currentJars(name) = timestamp
// Add it to our class loader
val url = new File(SparkFiles.getRootDirectory(),
localName).toURI.toURL
- if (!urlClassLoader.getURLs().contains(url)) {
+ if (!state.urlClassLoader.getURLs().contains(url)) {
logInfo(s"Adding $url to class loader")
- urlClassLoader.addURL(url)
+ state.urlClassLoader.addURL(url)
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala
b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala
index 790a3a51a09..89f0c05c815 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala
@@ -19,6 +19,7 @@ package org.apache.spark.scheduler
import java.util.Properties
+import org.apache.spark.JobArtifactSet
import org.apache.spark.util.CallSite
/**
@@ -38,6 +39,7 @@ import org.apache.spark.util.CallSite
* ShuffleMapStage for submitMapStage).
* @param callSite Where this job was initiated in the user's program (shown
on UI).
* @param listener A listener to notify if tasks in this job finish or the job
fails.
+ * @param artifacts A set of artifacts that this job may use.
* @param properties Scheduling properties attached to the job, such as fair
scheduler pool name.
*/
private[spark] class ActiveJob(
@@ -45,6 +47,7 @@ private[spark] class ActiveJob(
val finalStage: Stage,
val callSite: CallSite,
val listener: JobListener,
+ val artifacts: JobArtifactSet,
val properties: Properties) {
/**
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 64a8192f8e1..00f505fa5a9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -944,6 +944,7 @@ private[spark] class DAGScheduler(
val waiter = new JobWaiter[U](this, jobId, partitions.size, resultHandler)
eventProcessLoop.post(JobSubmitted(
jobId, rdd, func2, partitions.toArray, callSite, waiter,
+ JobArtifactSet.getActiveOrDefault(sc),
Utils.cloneProperties(properties)))
waiter
}
@@ -1023,7 +1024,7 @@ private[spark] class DAGScheduler(
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
eventProcessLoop.post(JobSubmitted(
jobId, rdd, func2, rdd.partitions.indices.toArray, callSite, listener,
- clonedProperties))
+ JobArtifactSet.getActiveOrDefault(sc), clonedProperties))
listener.awaitResult() // Will throw an exception if the job fails
}
@@ -1065,7 +1066,8 @@ private[spark] class DAGScheduler(
this, jobId, 1,
(_: Int, r: MapOutputStatistics) => callback(r))
eventProcessLoop.post(MapStageSubmitted(
- jobId, dependency, callSite, waiter, Utils.cloneProperties(properties)))
+ jobId, dependency, callSite, waiter,
JobArtifactSet.getActiveOrDefault(sc),
+ Utils.cloneProperties(properties)))
waiter
}
@@ -1268,6 +1270,7 @@ private[spark] class DAGScheduler(
partitions: Array[Int],
callSite: CallSite,
listener: JobListener,
+ artifacts: JobArtifactSet,
properties: Properties): Unit = {
var finalStage: ResultStage = null
try {
@@ -1288,7 +1291,7 @@ private[spark] class DAGScheduler(
messageScheduler.schedule(
new Runnable {
override def run(): Unit =
eventProcessLoop.post(JobSubmitted(jobId, finalRDD, func,
- partitions, callSite, listener, properties))
+ partitions, callSite, listener, artifacts, properties))
},
timeIntervalNumTasksCheck,
TimeUnit.SECONDS
@@ -1309,7 +1312,7 @@ private[spark] class DAGScheduler(
// Job submitted, clear internal data.
barrierJobIdToNumTasksCheckFailures.remove(jobId)
- val job = new ActiveJob(jobId, finalStage, callSite, listener, properties)
+ val job = new ActiveJob(jobId, finalStage, callSite, listener, artifacts,
properties)
clearCacheLocs()
logInfo("Got job %s (%s) with %d output partitions".format(
job.jobId, callSite.shortForm, partitions.length))
@@ -1333,6 +1336,7 @@ private[spark] class DAGScheduler(
dependency: ShuffleDependency[_, _, _],
callSite: CallSite,
listener: JobListener,
+ artifacts: JobArtifactSet,
properties: Properties): Unit = {
// Submitting this map stage might still require the creation of some
parent stages, so make
// sure that happens.
@@ -1348,7 +1352,7 @@ private[spark] class DAGScheduler(
return
}
- val job = new ActiveJob(jobId, finalStage, callSite, listener, properties)
+ val job = new ActiveJob(jobId, finalStage, callSite, listener, artifacts,
properties)
clearCacheLocs()
logInfo("Got map stage job %s (%s) with %d output partitions".format(
jobId, callSite.shortForm, dependency.rdd.partitions.length))
@@ -1590,6 +1594,8 @@ private[spark] class DAGScheduler(
return
}
+ val artifacts = jobIdToActiveJob(jobId).artifacts
+
val tasks: Seq[Task[_]] = try {
val serializedTaskMetrics =
closureSerializer.serialize(stage.latestInfo.taskMetrics).array()
stage match {
@@ -1600,8 +1606,9 @@ private[spark] class DAGScheduler(
val part = partitions(id)
stage.pendingPartitions += id
new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber,
taskBinary,
- part, stage.numPartitions, locs, properties,
serializedTaskMetrics, Option(jobId),
- Option(sc.applicationId), sc.applicationAttemptId,
stage.rdd.isBarrier())
+ part, stage.numPartitions, locs, artifacts, properties,
serializedTaskMetrics,
+ Option(jobId), Option(sc.applicationId), sc.applicationAttemptId,
+ stage.rdd.isBarrier())
}
case stage: ResultStage =>
@@ -1610,9 +1617,9 @@ private[spark] class DAGScheduler(
val part = partitions(p)
val locs = taskIdToLocations(id)
new ResultTask(stage.id, stage.latestInfo.attemptNumber,
- taskBinary, part, stage.numPartitions, locs, id, properties,
serializedTaskMetrics,
- Option(jobId), Option(sc.applicationId), sc.applicationAttemptId,
- stage.rdd.isBarrier())
+ taskBinary, part, stage.numPartitions, locs, id, artifacts,
properties,
+ serializedTaskMetrics, Option(jobId), Option(sc.applicationId),
+ sc.applicationAttemptId, stage.rdd.isBarrier())
}
}
} catch {
@@ -2979,11 +2986,13 @@ private[scheduler] class
DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
}
private def doOnReceive(event: DAGSchedulerEvent): Unit = event match {
- case JobSubmitted(jobId, rdd, func, partitions, callSite, listener,
properties) =>
- dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, callSite,
listener, properties)
+ case JobSubmitted(jobId, rdd, func, partitions, callSite, listener,
artifacts, properties) =>
+ dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, callSite,
listener, artifacts,
+ properties)
- case MapStageSubmitted(jobId, dependency, callSite, listener, properties)
=>
- dagScheduler.handleMapStageSubmitted(jobId, dependency, callSite,
listener, properties)
+ case MapStageSubmitted(jobId, dependency, callSite, listener, artifacts,
properties) =>
+ dagScheduler.handleMapStageSubmitted(jobId, dependency, callSite,
listener, artifacts,
+ properties)
case StageCancelled(stageId, reason) =>
dagScheduler.handleStageCancellation(stageId, reason)
diff --git
a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index 6f2b778ca82..f8cd2742906 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -39,6 +39,7 @@ private[scheduler] case class JobSubmitted(
partitions: Array[Int],
callSite: CallSite,
listener: JobListener,
+ artifactSet: JobArtifactSet,
properties: Properties = null)
extends DAGSchedulerEvent
@@ -48,6 +49,7 @@ private[scheduler] case class MapStageSubmitted(
dependency: ShuffleDependency[_, _, _],
callSite: CallSite,
listener: JobListener,
+ artifactSet: JobArtifactSet,
properties: Properties = null)
extends DAGSchedulerEvent
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index cc3677fc4d4..3eae49aa3b9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -60,14 +60,15 @@ private[spark] class ResultTask[T, U](
numPartitions: Int,
locs: Seq[TaskLocation],
val outputId: Int,
+ artifacts: JobArtifactSet,
localProperties: Properties,
serializedTaskMetrics: Array[Byte],
jobId: Option[Int] = None,
appId: Option[String] = None,
appAttemptId: Option[String] = None,
isBarrier: Boolean = false)
- extends Task[U](stageId, stageAttemptId, partition.index, numPartitions,
localProperties,
- serializedTaskMetrics, jobId, appId, appAttemptId, isBarrier)
+ extends Task[U](stageId, stageAttemptId, partition.index, numPartitions,
artifacts,
+ localProperties, serializedTaskMetrics, jobId, appId, appAttemptId,
isBarrier)
with Serializable {
@transient private[this] val preferredLocs: Seq[TaskLocation] = {
diff --git
a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index b0687094108..641a900c893 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -39,6 +39,7 @@ import org.apache.spark.rdd.RDD
* @param partition partition of the RDD this task is associated with
* @param numPartitions Total number of partitions in the stage that this task
belongs to.
* @param locs preferred task execution locations for locality scheduling
+ * @param artifacts list of artifacts (may be session-specific) of the job
this task belongs to.
* @param localProperties copy of thread-local properties set by the user on
the driver side.
* @param serializedTaskMetrics a `TaskMetrics` that is created and serialized
on the driver side
* and sent to executor side.
@@ -57,19 +58,21 @@ private[spark] class ShuffleMapTask(
partition: Partition,
numPartitions: Int,
@transient private var locs: Seq[TaskLocation],
+ artifacts: JobArtifactSet,
localProperties: Properties,
serializedTaskMetrics: Array[Byte],
jobId: Option[Int] = None,
appId: Option[String] = None,
appAttemptId: Option[String] = None,
isBarrier: Boolean = false)
- extends Task[MapStatus](stageId, stageAttemptId, partition.index,
numPartitions, localProperties,
- serializedTaskMetrics, jobId, appId, appAttemptId, isBarrier)
+ extends Task[MapStatus](stageId, stageAttemptId, partition.index,
numPartitions, artifacts,
+ localProperties, serializedTaskMetrics, jobId, appId, appAttemptId,
isBarrier)
with Logging {
/** A constructor used only in test suites. This does not require passing in
an RDD. */
def this(partitionId: Int) = {
- this(0, 0, null, new Partition { override def index: Int = 0 }, 1, null,
new Properties, null)
+ this(0, 0, null, new Partition { override def index: Int = 0 }, 1, null,
null, new Properties,
+ null)
}
@transient private val preferredLocs: Seq[TaskLocation] = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 001e3220e73..39667ea2364 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -45,6 +45,7 @@ import org.apache.spark.util._
* @param stageAttemptId attempt id of the stage this task belongs to
* @param partitionId index of the number in the RDD
* @param numPartitions Total number of partitions in the stage that this task
belongs to.
+ * @param artifacts list of artifacts (may be session-specific) of the job
this task belongs to.
* @param localProperties copy of thread-local properties set by the user on
the driver side.
* @param serializedTaskMetrics a `TaskMetrics` that is created and serialized
on the driver side
* and sent to executor side.
@@ -61,6 +62,7 @@ private[spark] abstract class Task[T](
val stageAttemptId: Int,
val partitionId: Int,
val numPartitions: Int,
+ val artifacts: JobArtifactSet,
@transient var localProperties: Properties = new Properties,
// The default value is only used in tests.
serializedTaskMetrics: Array[Byte] =
diff --git
a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala
b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala
index 88138519983..0e30c165457 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala
@@ -26,6 +26,7 @@ import scala.collection.JavaConverters._
import scala.collection.immutable
import scala.collection.mutable.{ArrayBuffer, HashMap, Map}
+import org.apache.spark.JobArtifactSet
import org.apache.spark.resource.ResourceInformation
import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream,
Utils}
@@ -53,9 +54,7 @@ private[spark] class TaskDescription(
val name: String,
val index: Int, // Index within this task's TaskSet
val partitionId: Int,
- val addedFiles: Map[String, Long],
- val addedJars: Map[String, Long],
- val addedArchives: Map[String, Long],
+ val artifacts: JobArtifactSet,
val properties: Properties,
val cpus: Int,
val resources: immutable.Map[String, ResourceInformation],
@@ -97,14 +96,8 @@ private[spark] object TaskDescription {
dataOut.writeInt(taskDescription.index)
dataOut.writeInt(taskDescription.partitionId)
- // Write files.
- serializeStringLongMap(taskDescription.addedFiles, dataOut)
-
- // Write jars.
- serializeStringLongMap(taskDescription.addedJars, dataOut)
-
- // Write archives.
- serializeStringLongMap(taskDescription.addedArchives, dataOut)
+ // Write artifacts
+ serializeArtifacts(taskDescription.artifacts, dataOut)
// Write properties.
dataOut.writeInt(taskDescription.properties.size())
@@ -130,6 +123,38 @@ private[spark] object TaskDescription {
bytesOut.toByteBuffer
}
+ private def deserializeOptionString(in: DataInputStream): Option[String] = {
+ if (in.readBoolean()) {
+ Some(in.readUTF())
+ } else {
+ None
+ }
+ }
+
+ private def deserializeArtifacts(dataIn: DataInputStream): JobArtifactSet = {
+ new JobArtifactSet(
+ uuid = deserializeOptionString(dataIn),
+ replClassDirUri = deserializeOptionString(dataIn),
+ jars = immutable.Map(deserializeStringLongMap(dataIn).toSeq: _*),
+ files = immutable.Map(deserializeStringLongMap(dataIn).toSeq: _*),
+ archives = immutable.Map(deserializeStringLongMap(dataIn).toSeq: _*))
+ }
+
+ private def serializeOptionString(str: Option[String], out:
DataOutputStream): Unit = {
+ out.writeBoolean(str.isDefined)
+ if (str.isDefined) {
+ out.writeUTF(str.get)
+ }
+ }
+
+ private def serializeArtifacts(artifacts: JobArtifactSet, dataOut:
DataOutputStream): Unit = {
+ serializeOptionString(artifacts.uuid, dataOut)
+ serializeOptionString(artifacts.replClassDirUri, dataOut)
+ serializeStringLongMap(Map(artifacts.jars.toSeq: _*), dataOut)
+ serializeStringLongMap(Map(artifacts.files.toSeq: _*), dataOut)
+ serializeStringLongMap(Map(artifacts.archives.toSeq: _*), dataOut)
+ }
+
private def deserializeStringLongMap(dataIn: DataInputStream):
HashMap[String, Long] = {
val map = new HashMap[String, Long]()
val mapSize = dataIn.readInt()
@@ -171,14 +196,8 @@ private[spark] object TaskDescription {
val index = dataIn.readInt()
val partitionId = dataIn.readInt()
- // Read files.
- val taskFiles = deserializeStringLongMap(dataIn)
-
- // Read jars.
- val taskJars = deserializeStringLongMap(dataIn)
-
- // Read archives.
- val taskArchives = deserializeStringLongMap(dataIn)
+ // Read artifacts.
+ val artifacts = deserializeArtifacts(dataIn)
// Read properties.
val properties = new Properties()
@@ -200,7 +219,7 @@ private[spark] object TaskDescription {
// Create a sub-buffer for the serialized task into its own buffer (to be
deserialized later).
val serializedTask = byteBuffer.slice()
- new TaskDescription(taskId, attemptNumber, executorId, name, index,
partitionId, taskFiles,
- taskJars, taskArchives, properties, cpus, resources, serializedTask)
+ new TaskDescription(taskId, attemptNumber, executorId, name, index,
partitionId, artifacts,
+ properties, cpus, resources, serializedTask)
}
}
diff --git
a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 20a1943fa69..69b626029e4 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -63,11 +63,6 @@ private[spark] class TaskSetManager(
private val conf = sched.sc.conf
- // SPARK-21563 make a copy of the jars/files so they are consistent across
the TaskSet
- private val addedJars = HashMap[String, Long](sched.sc.addedJars.toSeq: _*)
- private val addedFiles = HashMap[String, Long](sched.sc.addedFiles.toSeq: _*)
- private val addedArchives = HashMap[String,
Long](sched.sc.addedArchives.toSeq: _*)
-
val maxResultSize = conf.get(config.MAX_RESULT_SIZE)
// Serializer for closures and tasks.
@@ -568,9 +563,7 @@ private[spark] class TaskSetManager(
tName,
index,
task.partitionId,
- addedFiles,
- addedJars,
- addedArchives,
+ task.artifacts,
task.localProperties,
taskCpus,
taskResourceAssignments,
diff --git a/core/src/test/scala/org/apache/spark/JobArtifactSetSuite.scala
b/core/src/test/scala/org/apache/spark/JobArtifactSetSuite.scala
new file mode 100644
index 00000000000..df09de1483e
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/JobArtifactSetSuite.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.io.File
+
+class JobArtifactSetSuite extends SparkFunSuite with LocalSparkContext {
+ test("JobArtifactSet uses resources from SparkContext") {
+ withTempDir { dir =>
+ val jarPath = File.createTempFile("testJar", ".jar", dir).getAbsolutePath
+ val filePath = File.createTempFile("testFile", ".txt",
dir).getAbsolutePath
+ val archivePath = File.createTempFile("testZip", ".zip",
dir).getAbsolutePath
+
+ val conf = new SparkConf()
+ .setAppName("test")
+ .setMaster("local")
+ .set("spark.repl.class.uri", "dummyUri")
+ sc = new SparkContext(conf)
+
+ sc.addJar(jarPath)
+ sc.addFile(filePath)
+ sc.addJar(archivePath)
+
+ val artifacts = JobArtifactSet.getActiveOrDefault(sc)
+ assert(artifacts.archives == sc.addedArchives)
+ assert(artifacts.files == sc.addedFiles)
+ assert(artifacts.jars == sc.addedJars)
+ assert(artifacts.replClassDirUri.contains("dummyUri"))
+ }
+ }
+
+ test("The active JobArtifactSet is fetched if set") {
+ withTempDir { dir =>
+ val jarPath = File.createTempFile("testJar", ".jar", dir).getAbsolutePath
+ val filePath = File.createTempFile("testFile", ".txt",
dir).getAbsolutePath
+ val archivePath = File.createTempFile("testZip", ".zip",
dir).getAbsolutePath
+
+ val conf = new SparkConf()
+ .setAppName("test")
+ .setMaster("local")
+ .set("spark.repl.class.uri", "dummyUri")
+ sc = new SparkContext(conf)
+
+ sc.addJar(jarPath)
+ sc.addFile(filePath)
+ sc.addJar(archivePath)
+
+ val artifactSet1 = new JobArtifactSet(
+ Some("123"),
+ Some("abc"),
+ Map("a" -> 1),
+ Map("b" -> 2),
+ Map("c" -> 3)
+ )
+
+ val artifactSet2 = new JobArtifactSet(
+ Some("789"),
+ Some("hjk"),
+ Map("x" -> 7),
+ Map("y" -> 8),
+ Map("z" -> 9)
+ )
+
+ JobArtifactSet.withActive(artifactSet1) {
+ JobArtifactSet.withActive(artifactSet2) {
+ assert(JobArtifactSet.getActiveOrDefault(sc) == artifactSet2)
+ }
+ assert(JobArtifactSet.getActiveOrDefault(sc) == artifactSet1)
+ }
+ }
+ }
+}
diff --git
a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala
b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala
index 7ba5dd4793b..9c61b1f8c27 100644
---
a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala
+++
b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala
@@ -24,7 +24,6 @@ import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.concurrent.TrieMap
-import scala.collection.mutable
import scala.concurrent.duration._
import org.json4s.{DefaultFormats, Extraction}
@@ -307,7 +306,7 @@ class CoarseGrainedExecutorBackendSuite extends
SparkFunSuite
// We don't really verify the data, just pass it around.
val data = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4))
val taskDescription = new TaskDescription(taskId, 2, "1", "TASK
1000000", 19,
- 1, mutable.Map.empty, mutable.Map.empty, mutable.Map.empty, new
Properties, 1,
+ 1, JobArtifactSet(), new Properties, 1,
Map(GPU -> new ResourceInformation(GPU, Array("0", "1"))), data)
val serializedTaskDescription = TaskDescription.encode(taskDescription)
backend.rpcEnv.setupEndpoint("Executor 1", backend)
@@ -423,7 +422,7 @@ class CoarseGrainedExecutorBackendSuite extends
SparkFunSuite
// Fake tasks with different taskIds.
val taskDescriptions = (1 to numTasks).map {
taskId => new TaskDescription(taskId, 2, "1", s"TASK $taskId", 19,
- 1, mutable.Map.empty, mutable.Map.empty, mutable.Map.empty, new
Properties, 1,
+ 1, JobArtifactSet(), new Properties, 1,
Map(GPU -> new ResourceInformation(GPU, Array("0", "1"))), data)
}
assert(taskDescriptions.length == numTasks)
@@ -512,7 +511,7 @@ class CoarseGrainedExecutorBackendSuite extends
SparkFunSuite
// Fake tasks with different taskIds.
val taskDescriptions = (1 to numTasks).map {
taskId => new TaskDescription(taskId, 2, "1", s"TASK $taskId", 19,
- 1, mutable.Map.empty, mutable.Map.empty, mutable.Map.empty, new
Properties, 1,
+ 1, JobArtifactSet(), new Properties, 1,
Map(GPU -> new ResourceInformation(GPU, Array("0", "1"))), data)
}
assert(taskDescriptions.length == numTasks)
diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
index 46f41195ebd..72a6c7555c7 100644
--- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -26,7 +26,7 @@ import java.util.concurrent.{CountDownLatch, TimeUnit}
import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.immutable
-import scala.collection.mutable.{ArrayBuffer, Map}
+import scala.collection.mutable.ArrayBuffer
import scala.concurrent.duration._
import com.google.common.cache.{CacheBuilder, CacheLoader}
@@ -548,9 +548,10 @@ class ExecutorSuite extends SparkFunSuite
// and takes a long time to finish because file download is slow:
val slowLibraryDownloadThread = new Thread(() => {
executor.updateDependencies(
- Map.empty,
- Map.empty,
- Map.empty,
+ immutable.Map.empty,
+ immutable.Map.empty,
+ immutable.Map.empty,
+ executor.defaultSessionState,
Some(startLatch),
Some(endLatch))
})
@@ -563,9 +564,10 @@ class ExecutorSuite extends SparkFunSuite
// dependency update:
val blockedLibraryDownloadThread = new Thread(() => {
executor.updateDependencies(
- Map.empty,
- Map.empty,
- Map.empty)
+ immutable.Map.empty,
+ immutable.Map.empty,
+ immutable.Map.empty,
+ executor.defaultSessionState)
})
blockedLibraryDownloadThread.start()
eventually(timeout(10.seconds), interval(100.millis)) {
@@ -621,6 +623,7 @@ class ExecutorSuite extends SparkFunSuite
numPartitions = 1,
locs = Seq(),
outputId = 0,
+ JobArtifactSet(),
localProperties = new Properties(),
serializedTaskMetrics = serializedTaskMetrics
)
@@ -636,9 +639,7 @@ class ExecutorSuite extends SparkFunSuite
name = "",
index = 0,
partitionId = 0,
- addedFiles = Map[String, Long](),
- addedJars = Map[String, Long](),
- addedArchives = Map[String, Long](),
+ JobArtifactSet(),
properties = new Properties,
cpus = 1,
resources = immutable.Map[String, ResourceInformation](),
diff --git
a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala
b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala
index c4d4fd7d80e..cb82c2e0a45 100644
---
a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala
+++
b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala
@@ -257,9 +257,7 @@ class CoarseGrainedSchedulerBackendSuite extends
SparkFunSuite with LocalSparkCo
val taskResources = Map(GPU -> new ResourceInformation(GPU, Array("0")))
val taskCpus = 1
val taskDescs: Seq[Seq[TaskDescription]] = Seq(Seq(new TaskDescription(1,
0, "1",
- "t1", 0, 1, mutable.Map.empty[String, Long],
- mutable.Map.empty[String, Long], mutable.Map.empty[String, Long],
- new Properties(), taskCpus, taskResources, bytebuffer)))
+ "t1", 0, 1, JobArtifactSet(), new Properties(), taskCpus, taskResources,
bytebuffer)))
val ts = backend.getTaskSchedulerImpl()
when(ts.resourceOffers(any[IndexedSeq[WorkerOffer]],
any[Boolean])).thenReturn(taskDescs)
@@ -365,9 +363,7 @@ class CoarseGrainedSchedulerBackendSuite extends
SparkFunSuite with LocalSparkCo
val taskResources = Map(GPU -> new ResourceInformation(GPU, Array("0")))
val taskCpus = 1
val taskDescs: Seq[Seq[TaskDescription]] = Seq(Seq(new TaskDescription(1,
0, "1",
- "t1", 0, 1, mutable.Map.empty[String, Long],
- mutable.Map.empty[String, Long], mutable.Map.empty[String, Long],
- new Properties(), taskCpus, taskResources, bytebuffer)))
+ "t1", 0, 1, JobArtifactSet(), new Properties(), taskCpus, taskResources,
bytebuffer)))
val ts = backend.getTaskSchedulerImpl()
when(ts.resourceOffers(any[IndexedSeq[WorkerOffer]],
any[Boolean])).thenReturn(taskDescs)
@@ -459,9 +455,7 @@ class CoarseGrainedSchedulerBackendSuite extends
SparkFunSuite with LocalSparkCo
// Task cpus can be different from default resource profile when
TaskResourceProfile is used.
val taskCpus = 2
val taskDescs: Seq[Seq[TaskDescription]] = Seq(Seq(new TaskDescription(1,
0, "1",
- "t1", 0, 1, mutable.Map.empty[String, Long],
- mutable.Map.empty[String, Long], mutable.Map.empty[String, Long],
- new Properties(), taskCpus, Map.empty, bytebuffer)))
+ "t1", 0, 1, JobArtifactSet(), new Properties(), taskCpus, Map.empty,
bytebuffer)))
when(ts.resourceOffers(any[IndexedSeq[WorkerOffer]],
any[Boolean])).thenReturn(taskDescs)
backend.driverEndpoint.send(ReviveOffers)
diff --git
a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 73ee879ad53..3aeb52cd37d 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -491,18 +491,21 @@ class DAGSchedulerSuite extends SparkFunSuite with
TempLocalSparkContext with Ti
partitions: Array[Int],
func: (TaskContext, Iterator[_]) => _ = jobComputeFunc,
listener: JobListener = jobListener,
+ artifacts: JobArtifactSet = JobArtifactSet(sc),
properties: Properties = null): Int = {
val jobId = scheduler.nextJobId.getAndIncrement()
- runEvent(JobSubmitted(jobId, rdd, func, partitions, CallSite("", ""),
listener, properties))
+ runEvent(JobSubmitted(jobId, rdd, func, partitions, CallSite("", ""),
listener, artifacts,
+ properties))
jobId
}
/** Submits a map stage to the scheduler and returns the job id. */
private def submitMapStage(
shuffleDep: ShuffleDependency[_, _, _],
- listener: JobListener = jobListener): Int = {
+ listener: JobListener = jobListener,
+ artifacts: JobArtifactSet = JobArtifactSet(sc)): Int = {
val jobId = scheduler.nextJobId.getAndIncrement()
- runEvent(MapStageSubmitted(jobId, shuffleDep, CallSite("", ""), listener))
+ runEvent(MapStageSubmitted(jobId, shuffleDep, CallSite("", ""), listener,
artifacts))
jobId
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
index 6ab56d3fffe..2f65b608a46 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
@@ -19,7 +19,7 @@ package org.apache.spark.scheduler
import java.util.Properties
-import org.apache.spark.{Partition, SparkEnv, TaskContext}
+import org.apache.spark.{JobArtifactSet, Partition, SparkEnv, TaskContext}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.resource.ResourceProfile
@@ -30,8 +30,8 @@ class FakeTask(
serializedTaskMetrics: Array[Byte] =
SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array(),
isBarrier: Boolean = false)
- extends Task[Int](stageId, 0, partitionId, 1, new Properties,
serializedTaskMetrics,
- isBarrier = isBarrier) {
+ extends Task[Int](stageId, 0, partitionId, 1,
JobArtifactSet.defaultArtifactSet(),
+ new Properties, serializedTaskMetrics, isBarrier = isBarrier) {
override def runTask(context: TaskContext): Int = 0
override def preferredLocations: Seq[TaskLocation] = prefLocs
@@ -96,7 +96,7 @@ object FakeTask {
val tasks = Array.tabulate[Task[_]](numTasks) { i =>
new ShuffleMapTask(stageId, stageAttemptId, null, new Partition {
override def index: Int = i
- }, 1, prefLocs(i), new Properties,
+ }, 1, prefLocs(i), JobArtifactSet.defaultArtifactSet(), new Properties,
SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array())
}
new TaskSet(tasks, stageId, stageAttemptId, priority = priority, null,
diff --git
a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
index 2631ab2a92a..b1e1e9c50a2 100644
---
a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
+++
b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
@@ -19,13 +19,13 @@ package org.apache.spark.scheduler
import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
-import org.apache.spark.TaskContext
+import org.apache.spark.{JobArtifactSet, TaskContext}
/**
* A Task implementation that fails to serialize.
*/
private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int)
- extends Task[Array[Byte]](stageId, 0, 0, 1) {
+ extends Task[Array[Byte]](stageId, 0, 0, 1, JobArtifactSet()) {
override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte]
override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]()
diff --git
a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index fcbc734e8bd..f350e3cda51 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -70,7 +70,7 @@ class TaskContextSuite extends SparkFunSuite with
BeforeAndAfter with LocalSpark
val func = (c: TaskContext, i: Iterator[String]) => i.next()
val taskBinary =
sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func))))
val task = new ResultTask[String, String](
- 0, 0, taskBinary, rdd.partitions(0), 1, Seq.empty, 0, new Properties,
+ 0, 0, taskBinary, rdd.partitions(0), 1, Seq.empty, 0,
JobArtifactSet(sc), new Properties,
closureSerializer.serialize(TaskMetrics.registered).array())
intercept[RuntimeException] {
task.run(0, 0, null, 1, null, Option.empty)
@@ -92,7 +92,7 @@ class TaskContextSuite extends SparkFunSuite with
BeforeAndAfter with LocalSpark
val func = (c: TaskContext, i: Iterator[String]) => i.next()
val taskBinary =
sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func))))
val task = new ResultTask[String, String](
- 0, 0, taskBinary, rdd.partitions(0), 1, Seq.empty, 0, new Properties,
+ 0, 0, taskBinary, rdd.partitions(0), 1, Seq.empty, 0,
JobArtifactSet(sc), new Properties,
closureSerializer.serialize(TaskMetrics.registered).array())
intercept[RuntimeException] {
task.run(0, 0, null, 1, null, Option.empty)
@@ -160,7 +160,8 @@ class TaskContextSuite extends SparkFunSuite with
BeforeAndAfter with LocalSpark
})
val e = intercept[TaskContextSuite.FakeTaskFailureException] {
- context.runTaskWithListeners(new Task[Int](0, 0, 0, 1,
serializedTaskMetrics = Array.empty) {
+ context.runTaskWithListeners(new Task[Int](0, 0, 0, 1, JobArtifactSet(),
+ serializedTaskMetrics = Array.empty) {
override def runTask(context: TaskContext): Int = {
throw new TaskContextSuite.FakeTaskFailureException
}
@@ -191,7 +192,8 @@ class TaskContextSuite extends SparkFunSuite with
BeforeAndAfter with LocalSpark
})
val e = intercept[TaskContextSuite.FakeTaskFailureException] {
- context.runTaskWithListeners(new Task[Int](0, 0, 0, 1,
serializedTaskMetrics = Array.empty) {
+ context.runTaskWithListeners(new Task[Int](0, 0, 0, 1, JobArtifactSet(),
+ serializedTaskMetrics = Array.empty) {
override def runTask(context: TaskContext): Int = {
throw new TaskContextSuite.FakeTaskFailureException
}
@@ -222,7 +224,8 @@ class TaskContextSuite extends SparkFunSuite with
BeforeAndAfter with LocalSpark
})
val e = intercept[TaskCompletionListenerException] {
- context.runTaskWithListeners(new Task[Int](0, 0, 0, 1,
serializedTaskMetrics = Array.empty) {
+ context.runTaskWithListeners(new Task[Int](0, 0, 0, 1, JobArtifactSet(),
+ serializedTaskMetrics = Array.empty) {
override def runTask(context: TaskContext): Int = 0
})
}
@@ -252,7 +255,8 @@ class TaskContextSuite extends SparkFunSuite with
BeforeAndAfter with LocalSpark
})
val e = intercept[TaskCompletionListenerException] {
- context.runTaskWithListeners(new Task[Int](0, 0, 0, 1,
serializedTaskMetrics = Array.empty) {
+ context.runTaskWithListeners(new Task[Int](0, 0, 0, 1, JobArtifactSet(),
+ serializedTaskMetrics = Array.empty) {
override def runTask(context: TaskContext): Int = 0
})
}
@@ -284,7 +288,8 @@ class TaskContextSuite extends SparkFunSuite with
BeforeAndAfter with LocalSpark
})
val e = intercept[TaskCompletionListenerException] {
- context.runTaskWithListeners(new Task[Int](0, 0, 0, 1,
serializedTaskMetrics = Array.empty) {
+ context.runTaskWithListeners(new Task[Int](0, 0, 0, 1, JobArtifactSet(),
+ serializedTaskMetrics = Array.empty) {
override def runTask(context: TaskContext): Int = 0
})
}
@@ -316,7 +321,8 @@ class TaskContextSuite extends SparkFunSuite with
BeforeAndAfter with LocalSpark
})
val e = intercept[TaskContextSuite.FakeTaskFailureException] {
- context.runTaskWithListeners(new Task[Int](0, 0, 0, 1,
serializedTaskMetrics = Array.empty) {
+ context.runTaskWithListeners(new Task[Int](0, 0, 0, 1, JobArtifactSet(),
+ serializedTaskMetrics = Array.empty) {
override def runTask(context: TaskContext): Int = {
throw new TaskContextSuite.FakeTaskFailureException
}
@@ -424,7 +430,7 @@ class TaskContextSuite extends SparkFunSuite with
BeforeAndAfter with LocalSpark
// Create a dummy task. We won't end up running this; we just want to
collect
// accumulator updates from it.
val taskMetrics = TaskMetrics.empty
- val task = new Task[Int](0, 0, 0, 1) {
+ val task = new Task[Int](0, 0, 0, 1, JobArtifactSet(sc)) {
context = new TaskContextImpl(0, 0, 0, 0L, 0, 1,
new TaskMemoryManager(SparkEnv.get.memoryManager, 0L),
new Properties,
@@ -447,7 +453,7 @@ class TaskContextSuite extends SparkFunSuite with
BeforeAndAfter with LocalSpark
// Create a dummy task. We won't end up running this; we just want to
collect
// accumulator updates from it.
val taskMetrics = TaskMetrics.registered
- val task = new Task[Int](0, 0, 0, 1) {
+ val task = new Task[Int](0, 0, 0, 1, JobArtifactSet(sc)) {
context = new TaskContextImpl(0, 0, 0, 0L, 0, 1,
new TaskMemoryManager(SparkEnv.get.memoryManager, 0L),
new Properties,
diff --git
a/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala
b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala
index 25d7ab88426..7f84806e1f8 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala
@@ -23,7 +23,7 @@ import java.util.Properties
import scala.collection.mutable.HashMap
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{JobArtifactSet, SparkFunSuite}
import org.apache.spark.resource.ResourceInformation
import org.apache.spark.resource.ResourceUtils.GPU
@@ -65,6 +65,14 @@ class TaskDescriptionSuite extends SparkFunSuite {
// Create a dummy byte buffer for the task.
val taskBuffer = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4))
+ val artifacts = new JobArtifactSet(
+ uuid = None,
+ replClassDirUri = None,
+ jars = Map(originalJars.toSeq: _*),
+ files = Map(originalFiles.toSeq: _*),
+ archives = Map(originalArchives.toSeq: _*)
+ )
+
val originalTaskDescription = new TaskDescription(
taskId = 1520589,
attemptNumber = 2,
@@ -72,9 +80,7 @@ class TaskDescriptionSuite extends SparkFunSuite {
name = "task for test",
index = 19,
partitionId = 1,
- originalFiles,
- originalJars,
- originalArchives,
+ artifacts,
originalProperties,
cpus = 2,
originalResources,
@@ -91,9 +97,7 @@ class TaskDescriptionSuite extends SparkFunSuite {
assert(decodedTaskDescription.name === originalTaskDescription.name)
assert(decodedTaskDescription.index === originalTaskDescription.index)
assert(decodedTaskDescription.partitionId ===
originalTaskDescription.partitionId)
- assert(decodedTaskDescription.addedFiles.equals(originalFiles))
- assert(decodedTaskDescription.addedJars.equals(originalJars))
- assert(decodedTaskDescription.addedArchives.equals(originalArchives))
+ assert(decodedTaskDescription.artifacts.equals(artifacts))
assert(decodedTaskDescription.properties.equals(originalTaskDescription.properties))
assert(decodedTaskDescription.cpus.equals(originalTaskDescription.cpus))
assert(equalResources(decodedTaskDescription.resources,
originalTaskDescription.resources))
diff --git
a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
index 7d2b4f5221a..2dd3b0fda20 100644
---
a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
+++
b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
@@ -2155,11 +2155,11 @@ class TaskSchedulerImplSuite extends SparkFunSuite with
LocalSparkContext
new WorkerOffer("executor1", "host1", 1))
val task1 = new ShuffleMapTask(1, 0, null, new Partition {
override def index: Int = 0
- }, 1, Seq(TaskLocation("host0", "executor0")), new Properties, null)
+ }, 1, Seq(TaskLocation("host0", "executor0")), JobArtifactSet(sc), new
Properties, null)
val task2 = new ShuffleMapTask(1, 0, null, new Partition {
override def index: Int = 1
- }, 1, Seq(TaskLocation("host1", "executor1")), new Properties, null)
+ }, 1, Seq(TaskLocation("host1", "executor1")), JobArtifactSet(sc), new
Properties, null)
val taskSet = new TaskSet(Array(task1, task2), 0, 0, 0, null, 0, Some(0))
diff --git
a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index cb70dbb0289..10c1a72066f 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -184,7 +184,7 @@ class FakeTaskScheduler(
/**
* A Task implementation that results in a large serialized task.
*/
-class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0, 1) {
+class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0, 1,
JobArtifactSet()) {
val randomBuffer = new Array[Byte](TaskSetManager.TASK_SIZE_TO_WARN_KIB *
1024)
val random = new Random(0)
@@ -900,7 +900,7 @@ class TaskSetManagerSuite
val singleTask = new ShuffleMapTask(0, 0, null, new Partition {
override def index: Int = 0
- }, 1, Seq(TaskLocation("host1", "execA")), new Properties, null)
+ }, 1, Seq(TaskLocation("host1", "execA")), JobArtifactSet(sc), new
Properties, null)
val taskSet = new TaskSet(Array(singleTask), 0, 0, 0,
null, ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID, Some(0))
val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES)
@@ -1528,9 +1528,9 @@ class TaskSetManagerSuite
// all tasks from the first taskset have the same jars
val taskOption1 = manager1.resourceOffer("exec1", "host1", NO_PREF)._1
- assert(taskOption1.get.addedJars === addedJarsPreTaskSet)
+ assert(taskOption1.get.artifacts.jars === addedJarsPreTaskSet)
val taskOption2 = manager1.resourceOffer("exec1", "host1", NO_PREF)._1
- assert(taskOption2.get.addedJars === addedJarsPreTaskSet)
+ assert(taskOption2.get.artifacts.jars === addedJarsPreTaskSet)
// even with a jar added mid-TaskSet
val jarPath =
Thread.currentThread().getContextClassLoader.getResource("TestUDTF.jar")
@@ -1539,14 +1539,14 @@ class TaskSetManagerSuite
assert(addedJarsPreTaskSet !== addedJarsMidTaskSet)
val taskOption3 = manager1.resourceOffer("exec1", "host1", NO_PREF)._1
// which should have the old version of the jars list
- assert(taskOption3.get.addedJars === addedJarsPreTaskSet)
+ assert(taskOption3.get.artifacts.jars === addedJarsPreTaskSet)
// and then the jar does appear in the next TaskSet
val taskSet2 = FakeTask.createTaskSet(1)
val manager2 = new TaskSetManager(sched, taskSet2, MAX_TASK_FAILURES,
clock = new ManualClock)
val taskOption4 = manager2.resourceOffer("exec1", "host1", NO_PREF)._1
- assert(taskOption4.get.addedJars === addedJarsMidTaskSet)
+ assert(taskOption4.get.artifacts.jars === addedJarsMidTaskSet)
}
test("SPARK-24677: Avoid NoSuchElementException from MedianHeap") {
diff --git
a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala
b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala
index fa4e800eb36..114b667e6a4 100644
---
a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala
+++
b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala
@@ -36,7 +36,7 @@ import org.mockito.ArgumentMatchers.{any, anyLong, eq => meq}
import org.mockito.Mockito._
import org.scalatestplus.mockito.MockitoSugar
-import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext,
+import org.apache.spark.{JobArtifactSet, LocalSparkContext, SparkConf,
SparkContext,
SparkFunSuite}
import org.apache.spark.deploy.mesos.config._
import org.apache.spark.executor.MesosExecutorBackend
@@ -262,9 +262,7 @@ class MesosFineGrainedSchedulerBackendSuite
name = "n1",
index = 0,
partitionId = 0,
- addedFiles = mutable.Map.empty[String, Long],
- addedJars = mutable.Map.empty[String, Long],
- addedArchives = mutable.Map.empty[String, Long],
+ artifacts = JobArtifactSet(),
properties = new Properties(),
cpus = 1,
resources = immutable.Map.empty[String, ResourceInformation],
@@ -377,9 +375,7 @@ class MesosFineGrainedSchedulerBackendSuite
name = "n1",
index = 0,
partitionId = 0,
- addedFiles = mutable.Map.empty[String, Long],
- addedJars = mutable.Map.empty[String, Long],
- addedArchives = mutable.Map.empty[String, Long],
+ artifacts = JobArtifactSet(),
properties = new Properties(),
cpus = 1,
resources = immutable.Map.empty[String, ResourceInformation],
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]