This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 6881ec0ebfab [SPARK-50366][SQL] Isolate user-defined tags on thread level for SparkSession in Classic
6881ec0ebfab is described below
commit 6881ec0ebfab36a59296157458a928331c1d5103
Author: Paddy Xu <[email protected]>
AuthorDate: Fri Nov 22 09:23:34 2024 +0900
[SPARK-50366][SQL] Isolate user-defined tags on thread level for SparkSession in Classic
### What changes were proposed in this pull request?
This PR changes the implementation of user-provided tags to be
thread-local, so that tags added by two threads to the same SparkSession do not
interfere with each other.
Overlaps (from the `SparkContext` perspective) are avoided by introducing a
thread-local random UUID which is attached to all tags in the same thread.
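In sketch form, the idea looks like this (a minimal, self-contained illustration; `TagIsolationSketch` is a hypothetical name, and the actual implementation lives in `SparkSession`, see the diff below):

```scala
import java.util.UUID
import scala.collection.mutable

// Hypothetical standalone sketch of the mechanism; the real code is in
// SparkSession below. Each thread gets its own random UUID and its own
// tag map; child threads inherit copies from the parent thread.
object TagIsolationSketch {
  private val threadUuid = new InheritableThreadLocal[String] {
    // Child threads keep the parent's UUID, so inherited tags stay interruptible.
    override def childValue(parent: String): String = parent
    override def initialValue(): String = UUID.randomUUID().toString
  }

  private val tags = new InheritableThreadLocal[mutable.Map[String, String]] {
    // Clone so later changes in the parent thread don't leak into child threads.
    override def childValue(parent: mutable.Map[String, String]): mutable.Map[String, String] =
      parent.clone()
    override def initialValue(): mutable.Map[String, String] = mutable.Map.empty
  }

  // The "real" tag handed to SparkContext embeds both the session and the thread
  // UUID, so identical user tags from two threads never collide.
  def addTag(sessionUUID: String, tag: String): Unit =
    tags.get().put(tag, s"spark-session-$sessionUUID-thread-${threadUuid.get()}-$tag")

  def getTags(): Set[String] = tags.get().keySet.toSet
}
```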
### Why are the changes needed?
To make tags isolated per thread.
### Does this PR introduce _any_ user-facing change?
Yes, user-provided tags are now isolated at the thread level: tags added in one thread are not visible to, and do not interfere with, other threads using the same SparkSession.
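For example (a hypothetical driver-side sketch, mirroring the new test case in the diff below), two threads that share one session no longer observe each other's tags:

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical usage sketch: both threads use the same SparkSession,
// but each now sees only the tags it added itself.
val spark = SparkSession.builder().master("local").getOrCreate()

val t1 = new Thread(() => {
  spark.addTag("etl")
  assert(spark.getTags() == Set("etl"))    // does not see "report"
})
val t2 = new Thread(() => {
  spark.addTag("report")
  assert(spark.getTags() == Set("report")) // does not see "etl"
})
t1.start(); t2.start()
t1.join(); t2.join()
```

Because the tag map is an InheritableThreadLocal, a child thread still inherits its parent's tags and can interrupt the parent's jobs, as the updated test demonstrates.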
### How was this patch tested?
Local test.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #48906 from xupefei/thread-isolated-tags.
Authored-by: Paddy Xu <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../scala/org/apache/spark/sql/SparkSession.scala | 47 +++++++---
.../apache/spark/sql/execution/SQLExecution.scala | 2 +-
...parkSessionJobTaggingAndCancellationSuite.scala | 102 +++++++++++++++++----
3 files changed, 120 insertions(+), 31 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index afc0a2d7df60..a7f85db12b21 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -20,9 +20,9 @@ package org.apache.spark.sql
import java.net.URI
import java.nio.file.Paths
import java.util.{ServiceLoader, UUID}
-import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicBoolean
+import scala.collection.mutable
import scala.concurrent.duration.DurationInt
import scala.jdk.CollectionConverters._
import scala.reflect.runtime.universe.TypeTag
@@ -133,14 +133,34 @@ class SparkSession private(
/** Tag to mark all jobs owned by this session. */
private[sql] lazy val sessionJobTag = s"spark-session-$sessionUUID"
+  /**
+   * A UUID that is unique on the thread level. Used by managedJobTags to make sure that the same
+   * tag from two threads does not overlap in the underlying SparkContext/SQLExecution.
+   */
+ private[sql] lazy val threadUuid = new InheritableThreadLocal[String] {
+ override def childValue(parent: String): String = parent
+
+ override def initialValue(): String = UUID.randomUUID().toString
+ }
+
  /**
   * A map to hold the mapping from user-defined tags to the real tags attached to Jobs.
-   * Real tag have the current session ID attached: `"tag1" -> s"spark-session-$sessionUUID-tag1"`.
+   * Real tags have the current session ID and thread UUID attached:
+   * `"tag1" -> s"spark-session-$sessionUUID-thread-$threadUuid-tag1"`
   */
@transient
- private[sql] lazy val managedJobTags: ConcurrentHashMap[String, String] = {
- new ConcurrentHashMap(parentManagedJobTags.asJava)
- }
+  private[sql] lazy val managedJobTags = new InheritableThreadLocal[mutable.Map[String, String]] {
+    override def childValue(parent: mutable.Map[String, String]): mutable.Map[String, String] = {
+      // Note: make a clone such that changes in the parent tags aren't reflected in
+      // those of the child threads.
+      parent.clone()
+    }
+
+    override def initialValue(): mutable.Map[String, String] = {
+      mutable.Map(parentManagedJobTags.toSeq: _*)
+    }
+  }
/** @inheritdoc */
def version: String = SPARK_VERSION
@@ -243,10 +263,10 @@ class SparkSession private(
Some(sessionState),
extensions,
Map.empty,
- managedJobTags.asScala.toMap)
+ managedJobTags.get().toMap)
result.sessionState // force copy of SessionState
    result.sessionState.artifactManager // force copy of ArtifactManager and its resources
- result.managedJobTags // force copy of userDefinedToRealTagsMap
+ result.managedJobTags // force copy of managedJobTags
result
}
@@ -550,17 +570,17 @@ class SparkSession private(
/** @inheritdoc */
override def addTag(tag: String): Unit = {
SparkContext.throwIfInvalidTag(tag)
- managedJobTags.put(tag, s"spark-session-$sessionUUID-$tag")
+    managedJobTags.get().put(tag, s"spark-session-$sessionUUID-thread-${threadUuid.get()}-$tag")
}
/** @inheritdoc */
- override def removeTag(tag: String): Unit = managedJobTags.remove(tag)
+ override def removeTag(tag: String): Unit = managedJobTags.get().remove(tag)
/** @inheritdoc */
- override def getTags(): Set[String] = managedJobTags.keys().asScala.toSet
+ override def getTags(): Set[String] = managedJobTags.get().keySet.toSet
/** @inheritdoc */
- override def clearTags(): Unit = managedJobTags.clear()
+ override def clearTags(): Unit = managedJobTags.get().clear()
/**
* Request to interrupt all currently running SQL operations of this session.
@@ -589,9 +609,8 @@ class SparkSession private(
* @since 4.0.0
*/
override def interruptTag(tag: String): Seq[String] = {
- val realTag = managedJobTags.get(tag)
- if (realTag == null) return Seq.empty
- doInterruptTag(realTag, s"part of cancelled job tags $tag")
+ val realTag = managedJobTags.get().get(tag)
+    realTag.map(doInterruptTag(_, s"part of cancelled job tags $tag")).getOrElse(Seq.empty)
}
private def doInterruptTag(tag: String, reason: String): Seq[String] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index e805aabe013c..242149010cee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -261,7 +261,7 @@ object SQLExecution extends Logging {
}
  private[sql] def withSessionTagsApplied[T](sparkSession: SparkSession)(block: => T): T = {
-    val allTags = sparkSession.managedJobTags.values().asScala.toSet + sparkSession.sessionJobTag
+    val allTags = sparkSession.managedJobTags.get().values.toSet + sparkSession.sessionJobTag
sparkSession.sparkContext.addJobTags(allTags)
try {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala
index 1ac51b408301..89500fe51f3a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql
-import java.util.concurrent.{ConcurrentHashMap, Semaphore, TimeUnit}
+import java.util.concurrent.{ConcurrentHashMap, Executors, Semaphore, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger
import scala.concurrent.{ExecutionContext, Future}
@@ -100,13 +100,14 @@ class SparkSessionJobTaggingAndCancellationSuite
assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES))
val activeJobsFuture =
-      session.sparkContext.cancelJobsWithTagWithFuture(session.managedJobTags.get("one"), "reason")
+ session.sparkContext.cancelJobsWithTagWithFuture(
+ session.managedJobTags.get()("one"), "reason")
val activeJob = ThreadUtils.awaitResult(activeJobsFuture, 60.seconds).head
    val actualTags = activeJob.properties.getProperty(SparkContext.SPARK_JOB_TAGS)
.split(SparkContext.SPARK_JOB_TAGS_SEP)
assert(actualTags.toSet == Set(
session.sessionJobTag,
- s"${session.sessionJobTag}-one",
+ s"${session.sessionJobTag}-thread-${session.threadUuid.get()}-one",
SQLExecution.executionIdJobTag(
session,
activeJob.properties.get(SQLExecution.EXECUTION_ROOT_ID_KEY).asInstanceOf[String].toLong)))
@@ -118,12 +119,12 @@ class SparkSessionJobTaggingAndCancellationSuite
val globalSession = SparkSession.builder().sparkContext(sc).getOrCreate()
    var (sessionA, sessionB, sessionC): (SparkSession, SparkSession, SparkSession) =
(null, null, null)
+    var (threadUuidA, threadUuidB, threadUuidC): (String, String, String) = (null, null, null)
// global ExecutionContext has only 2 threads in Apache Spark CI
// create own thread pool for four Futures used in this test
-    val numThreads = 3
-    val fpool = ThreadUtils.newForkJoinPool("job-tags-test-thread-pool", numThreads)
-    val executionContext = ExecutionContext.fromExecutorService(fpool)
+    val threadPool = Executors.newFixedThreadPool(3)
+    implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(threadPool)
try {
// Add a listener to release the semaphore once jobs are launched.
@@ -143,28 +144,50 @@ class SparkSessionJobTaggingAndCancellationSuite
}
})
+ var realTagOneForSessionA: String = null
+ var childThread: Thread = null
+ val childThreadLock = new Semaphore(0)
+
      // Note: since tags are added in the Future threads, they don't need to be cleared in between.
val jobA = Future {
sessionA = globalSession.cloneSession()
import globalSession.implicits._
+ threadUuidA = sessionA.threadUuid.get()
assert(sessionA.getTags() == Set())
sessionA.addTag("two")
assert(sessionA.getTags() == Set("two"))
sessionA.clearTags() // check that clearing all tags works
assert(sessionA.getTags() == Set())
sessionA.addTag("one")
+ realTagOneForSessionA = sessionA.managedJobTags.get()("one")
+ assert(realTagOneForSessionA ==
+ s"${sessionA.sessionJobTag}-thread-${sessionA.threadUuid.get()}-one")
assert(sessionA.getTags() == Set("one"))
+
+        // Create a child thread which inherits thread-local variables and tries to interrupt
+        // the job started from the parent thread. The child thread is blocked until the main
+        // thread releases the lock.
+ childThread = new Thread {
+ override def run(): Unit = {
+ assert(childThreadLock.tryAcquire(1, 20, TimeUnit.SECONDS))
+ assert(sessionA.getTags() == Set("one"))
+ assert(sessionA.interruptTag("one").size == 1)
+ }
+ }
+ childThread.start()
try {
sessionA.range(1, 10000).map { i => Thread.sleep(100); i }.count()
} finally {
+ childThread.interrupt()
          sessionA.clearTags() // clear for the case of thread reuse by another Future
}
- }(executionContext)
+ }
val jobB = Future {
sessionB = globalSession.cloneSession()
import globalSession.implicits._
+ threadUuidB = sessionB.threadUuid.get()
assert(sessionB.getTags() == Set())
sessionB.addTag("one")
sessionB.addTag("two")
@@ -176,11 +199,12 @@ class SparkSessionJobTaggingAndCancellationSuite
} finally {
          sessionB.clearTags() // clear for the case of thread reuse by another Future
}
- }(executionContext)
+ }
val jobC = Future {
sessionC = globalSession.cloneSession()
import globalSession.implicits._
+ threadUuidC = sessionC.threadUuid.get()
sessionC.addTag("foo")
sessionC.removeTag("foo")
        assert(sessionC.getTags() == Set()) // check that remove works removing the last tag
@@ -190,12 +214,13 @@ class SparkSessionJobTaggingAndCancellationSuite
} finally {
          sessionC.clearTags() // clear for the case of thread reuse by another Future
}
- }(executionContext)
+ }
// Block until four jobs have started.
assert(sem.tryAcquire(3, 1, TimeUnit.MINUTES))
// Tags are applied
+      def realUserTag(s: String, t: String, ta: String): String = s"spark-session-$s-thread-$t-$ta"
assert(jobProperties.size == 3)
for (ss <- Seq(sessionA, sessionB, sessionC)) {
      val jobProperty = jobProperties.values().asScala.filter(_.get(SparkContext.SPARK_JOB_TAGS)
@@ -207,15 +232,17 @@ class SparkSessionJobTaggingAndCancellationSuite
val executionRootIdTag = SQLExecution.executionIdJobTag(
ss,
jobProperty.head.get(SQLExecution.EXECUTION_ROOT_ID_KEY).asInstanceOf[String].toLong)
- val userTagsPrefix = s"spark-session-${ss.sessionUUID}-"
ss match {
case s if s == sessionA => assert(tags.toSet == Set(
- s.sessionJobTag, executionRootIdTag, s"${userTagsPrefix}one"))
+          s.sessionJobTag, executionRootIdTag, realUserTag(s.sessionUUID, threadUuidA, "one")))
case s if s == sessionB => assert(tags.toSet == Set(
-          s.sessionJobTag, executionRootIdTag, s"${userTagsPrefix}one", s"${userTagsPrefix}two"))
+ s.sessionJobTag,
+ executionRootIdTag,
+ realUserTag(s.sessionUUID, threadUuidB, "one"),
+ realUserTag(s.sessionUUID, threadUuidB, "two")))
case s if s == sessionC => assert(tags.toSet == Set(
- s.sessionJobTag, executionRootIdTag, s"${userTagsPrefix}boo"))
+          s.sessionJobTag, executionRootIdTag, realUserTag(s.sessionUUID, threadUuidC, "boo")))
}
}
@@ -239,8 +266,10 @@ class SparkSessionJobTaggingAndCancellationSuite
assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES))
assert(jobEnded.intValue == 1)
- // Another job cancelled
- assert(sessionA.interruptTag("one").size == 1)
+      // Another job cancelled. The next line cancels nothing because we're now in another thread.
+      // The real cancel is done through unblocking a child thread, which is waiting for a lock.
+ assert(sessionA.interruptTag("one").isEmpty)
+ childThreadLock.release()
val eA = intercept[SparkException] {
ThreadUtils.awaitResult(jobA, 1.minute)
}.getCause
@@ -257,7 +286,48 @@ class SparkSessionJobTaggingAndCancellationSuite
assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES))
assert(jobEnded.intValue == 3)
} finally {
- fpool.shutdownNow()
+ threadPool.shutdownNow()
+ }
+ }
+
+ test("Tags are isolated in multithreaded environment") {
+ // Custom thread pool for multi-threaded testing
+ val threadPool = Executors.newFixedThreadPool(2)
+    implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(threadPool)
+
+ val session = SparkSession.builder().master("local").getOrCreate()
+ @volatile var output1: Set[String] = null
+ @volatile var output2: Set[String] = null
+
+ def tag1(): Unit = {
+ session.addTag("tag1")
+ output1 = session.getTags()
+ }
+
+ def tag2(): Unit = {
+ session.addTag("tag2")
+ output2 = session.getTags()
+ }
+
+ try {
+ // Run tasks in separate threads
+ val future1 = Future {
+ tag1()
+ }
+ val future2 = Future {
+ tag2()
+ }
+
+ // Wait for threads to complete
+ ThreadUtils.awaitResult(Future.sequence(Seq(future1, future2)), 1.minute)
+
+ // Assert outputs
+ assert(output1 != null)
+ assert(output1 == Set("tag1"))
+ assert(output2 != null)
+ assert(output2 == Set("tag2"))
+ } finally {
+ threadPool.shutdownNow()
}
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]