This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6881ec0ebfab [SPARK-50366][SQL] Isolate user-defined tags on thread level for SparkSession in Classic
6881ec0ebfab is described below

commit 6881ec0ebfab36a59296157458a928331c1d5103
Author: Paddy Xu <[email protected]>
AuthorDate: Fri Nov 22 09:23:34 2024 +0900

    [SPARK-50366][SQL] Isolate user-defined tags on thread level for SparkSession in Classic
    
    ### What changes were proposed in this pull request?
    
    This PR changes the implementation of user-provided tags to be thread-local,
    so that tags added by two threads to the same SparkSession do not interfere
    with each other.
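
    A minimal sketch of the new behavior (it mirrors the new test added below;
    `spark` here stands for any Classic SparkSession and the tag names are
    illustrative):

    ```scala
    import java.util.concurrent.Executors
    import scala.concurrent.{Await, ExecutionContext, Future}
    import scala.concurrent.duration._

    implicit val ec: ExecutionContext =
      ExecutionContext.fromExecutor(Executors.newFixedThreadPool(2))

    // Two threads tag the same session concurrently.
    val f1 = Future { spark.addTag("tag1"); spark.getTags() }
    val f2 = Future { spark.addTag("tag2"); spark.getTags() }

    // Each thread now sees only the tags it added itself.
    assert(Await.result(f1, 1.minute) == Set("tag1"))
    assert(Await.result(f2, 1.minute) == Set("tag2"))
    ```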
    
    Overlaps (from the `SparkContext` perspective) are avoided by introducing a
    thread-local random UUID which is attached to all tags in the same thread.
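
    A rough sketch of how a user tag is expanded into the real tag that reaches
    the `SparkContext` (this mirrors the `SparkSession` changes in the diff
    below; `realTag` is a hypothetical helper for illustration):

    ```scala
    import java.util.UUID

    // Per-thread random UUID; child threads inherit the parent's value.
    val threadUuid = new InheritableThreadLocal[String] {
      override def childValue(parent: String): String = parent
      override def initialValue(): String = UUID.randomUUID().toString
    }

    // A user tag such as "one" then maps to a real tag of the form
    //   spark-session-<sessionUUID>-thread-<threadUuid>-one
    def realTag(sessionUUID: String, tag: String): String =
      s"spark-session-$sessionUUID-thread-${threadUuid.get()}-$tag"
    ```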
    
    ### Why are the changes needed?
    
    To make tags isolated per thread, so that concurrent threads sharing one
    SparkSession do not see or cancel each other's tagged jobs.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, user-provided tags are now isolated on the thread level within a session.
    
    ### How was this patch tested?
    
    Local testing, plus new and updated tests in `SparkSessionJobTaggingAndCancellationSuite` (see the diff below).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #48906 from xupefei/thread-isolated-tags.
    
    Authored-by: Paddy Xu <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../scala/org/apache/spark/sql/SparkSession.scala  |  47 +++++++---
 .../apache/spark/sql/execution/SQLExecution.scala  |   2 +-
 ...parkSessionJobTaggingAndCancellationSuite.scala | 102 +++++++++++++++++----
 3 files changed, 120 insertions(+), 31 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index afc0a2d7df60..a7f85db12b21 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -20,9 +20,9 @@ package org.apache.spark.sql
 import java.net.URI
 import java.nio.file.Paths
 import java.util.{ServiceLoader, UUID}
-import java.util.concurrent.ConcurrentHashMap
 import java.util.concurrent.atomic.AtomicBoolean
 
+import scala.collection.mutable
 import scala.concurrent.duration.DurationInt
 import scala.jdk.CollectionConverters._
 import scala.reflect.runtime.universe.TypeTag
@@ -133,14 +133,34 @@ class SparkSession private(
   /** Tag to mark all jobs owned by this session. */
   private[sql] lazy val sessionJobTag = s"spark-session-$sessionUUID"
 
+  /**
+   * A UUID that is unique on the thread level. Used by managedJobTags to make sure that the same
+   * tag from two threads does not overlap in the underlying SparkContext/SQLExecution.
+   */
+  private[sql] lazy val threadUuid = new InheritableThreadLocal[String] {
+    override def childValue(parent: String): String = parent
+
+    override def initialValue(): String = UUID.randomUUID().toString
+  }
+
   /**
    * A map to hold the mapping from user-defined tags to the real tags attached to Jobs.
-   * Real tag have the current session ID attached: `"tag1" -> s"spark-session-$sessionUUID-tag1"`.
+   * Real tags have the current session ID and thread UUID attached:
+   *   tag1 -> spark-session-$sessionUUID-thread-$threadUuid-tag1
+   *
    */
   @transient
-  private[sql] lazy val managedJobTags: ConcurrentHashMap[String, String] = {
-    new ConcurrentHashMap(parentManagedJobTags.asJava)
-  }
+  private[sql] lazy val managedJobTags = new InheritableThreadLocal[mutable.Map[String, String]] {
+      override def childValue(parent: mutable.Map[String, String]): mutable.Map[String, String] = {
+        // Note: make a clone such that changes in the parent tags aren't reflected in
+        // those of the child threads.
+        parent.clone()
+      }
+
+      override def initialValue(): mutable.Map[String, String] = {
+        mutable.Map(parentManagedJobTags.toSeq: _*)
+      }
+    }
 
   /** @inheritdoc */
   def version: String = SPARK_VERSION
@@ -243,10 +263,10 @@ class SparkSession private(
       Some(sessionState),
       extensions,
       Map.empty,
-      managedJobTags.asScala.toMap)
+      managedJobTags.get().toMap)
     result.sessionState // force copy of SessionState
     result.sessionState.artifactManager // force copy of ArtifactManager and its resources
-    result.managedJobTags // force copy of userDefinedToRealTagsMap
+    result.managedJobTags // force copy of managedJobTags
     result
   }
 
@@ -550,17 +570,17 @@ class SparkSession private(
   /** @inheritdoc */
   override def addTag(tag: String): Unit = {
     SparkContext.throwIfInvalidTag(tag)
-    managedJobTags.put(tag, s"spark-session-$sessionUUID-$tag")
+    managedJobTags.get().put(tag, s"spark-session-$sessionUUID-thread-${threadUuid.get()}-$tag")
   }
 
   /** @inheritdoc */
-  override def removeTag(tag: String): Unit = managedJobTags.remove(tag)
+  override def removeTag(tag: String): Unit = managedJobTags.get().remove(tag)
 
   /** @inheritdoc */
-  override def getTags(): Set[String] = managedJobTags.keys().asScala.toSet
+  override def getTags(): Set[String] = managedJobTags.get().keySet.toSet
 
   /** @inheritdoc */
-  override def clearTags(): Unit = managedJobTags.clear()
+  override def clearTags(): Unit = managedJobTags.get().clear()
 
   /**
    * Request to interrupt all currently running SQL operations of this session.
@@ -589,9 +609,8 @@ class SparkSession private(
    * @since 4.0.0
    */
   override def interruptTag(tag: String): Seq[String] = {
-    val realTag = managedJobTags.get(tag)
-    if (realTag == null) return Seq.empty
-    doInterruptTag(realTag, s"part of cancelled job tags $tag")
+    val realTag = managedJobTags.get().get(tag)
+    realTag.map(doInterruptTag(_, s"part of cancelled job tags $tag")).getOrElse(Seq.empty)
   }
 
   private def doInterruptTag(tag: String, reason: String): Seq[String] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index e805aabe013c..242149010cee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -261,7 +261,7 @@ object SQLExecution extends Logging {
   }
 
   private[sql] def withSessionTagsApplied[T](sparkSession: SparkSession)(block: => T): T = {
-    val allTags = sparkSession.managedJobTags.values().asScala.toSet + sparkSession.sessionJobTag
+    val allTags = sparkSession.managedJobTags.get().values.toSet + sparkSession.sessionJobTag
     sparkSession.sparkContext.addJobTags(allTags)
 
     try {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala
index 1ac51b408301..89500fe51f3a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql
 
-import java.util.concurrent.{ConcurrentHashMap, Semaphore, TimeUnit}
+import java.util.concurrent.{ConcurrentHashMap, Executors, Semaphore, TimeUnit}
 import java.util.concurrent.atomic.AtomicInteger
 
 import scala.concurrent.{ExecutionContext, Future}
@@ -100,13 +100,14 @@ class SparkSessionJobTaggingAndCancellationSuite
 
     assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES))
     val activeJobsFuture =
-      session.sparkContext.cancelJobsWithTagWithFuture(session.managedJobTags.get("one"), "reason")
+      session.sparkContext.cancelJobsWithTagWithFuture(
+        session.managedJobTags.get()("one"), "reason")
     val activeJob = ThreadUtils.awaitResult(activeJobsFuture, 60.seconds).head
    val actualTags = activeJob.properties.getProperty(SparkContext.SPARK_JOB_TAGS)
       .split(SparkContext.SPARK_JOB_TAGS_SEP)
     assert(actualTags.toSet == Set(
       session.sessionJobTag,
-      s"${session.sessionJobTag}-one",
+      s"${session.sessionJobTag}-thread-${session.threadUuid.get()}-one",
       SQLExecution.executionIdJobTag(
        session,
        activeJob.properties.get(SQLExecution.EXECUTION_ROOT_ID_KEY).asInstanceOf[String].toLong)))
@@ -118,12 +119,12 @@ class SparkSessionJobTaggingAndCancellationSuite
     val globalSession = SparkSession.builder().sparkContext(sc).getOrCreate()
    var (sessionA, sessionB, sessionC): (SparkSession, SparkSession, SparkSession) =
       (null, null, null)
+    var (threadUuidA, threadUuidB, threadUuidC): (String, String, String) = (null, null, null)
 
     // global ExecutionContext has only 2 threads in Apache Spark CI
     // create own thread pool for four Futures used in this test
-    val numThreads = 3
-    val fpool = ThreadUtils.newForkJoinPool("job-tags-test-thread-pool", numThreads)
-    val executionContext = ExecutionContext.fromExecutorService(fpool)
+    val threadPool = Executors.newFixedThreadPool(3)
+    implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(threadPool)
 
     try {
       // Add a listener to release the semaphore once jobs are launched.
@@ -143,28 +144,50 @@ class SparkSessionJobTaggingAndCancellationSuite
         }
       })
 
+      var realTagOneForSessionA: String = null
+      var childThread: Thread = null
+      val childThreadLock = new Semaphore(0)
+
      // Note: since tags are added in the Future threads, they don't need to be cleared in between.
       val jobA = Future {
         sessionA = globalSession.cloneSession()
         import globalSession.implicits._
 
+        threadUuidA = sessionA.threadUuid.get()
         assert(sessionA.getTags() == Set())
         sessionA.addTag("two")
         assert(sessionA.getTags() == Set("two"))
         sessionA.clearTags() // check that clearing all tags works
         assert(sessionA.getTags() == Set())
         sessionA.addTag("one")
+        realTagOneForSessionA = sessionA.managedJobTags.get()("one")
+        assert(realTagOneForSessionA ==
+          s"${sessionA.sessionJobTag}-thread-${sessionA.threadUuid.get()}-one")
         assert(sessionA.getTags() == Set("one"))
+
+        // Create a child thread which inherits thread-local variables and tries to interrupt
+        // the job started from the parent thread. The child thread is blocked until the main
+        // thread releases the lock.
+        childThread = new Thread {
+          override def run(): Unit = {
+            assert(childThreadLock.tryAcquire(1, 20, TimeUnit.SECONDS))
+            assert(sessionA.getTags() == Set("one"))
+            assert(sessionA.interruptTag("one").size == 1)
+          }
+        }
+        childThread.start()
         try {
           sessionA.range(1, 10000).map { i => Thread.sleep(100); i }.count()
         } finally {
+          childThread.interrupt()
          sessionA.clearTags() // clear for the case of thread reuse by another Future
         }
-      }(executionContext)
+      }
       val jobB = Future {
         sessionB = globalSession.cloneSession()
         import globalSession.implicits._
 
+        threadUuidB = sessionB.threadUuid.get()
         assert(sessionB.getTags() == Set())
         sessionB.addTag("one")
         sessionB.addTag("two")
@@ -176,11 +199,12 @@ class SparkSessionJobTaggingAndCancellationSuite
         } finally {
          sessionB.clearTags() // clear for the case of thread reuse by another Future
         }
-      }(executionContext)
+      }
       val jobC = Future {
         sessionC = globalSession.cloneSession()
         import globalSession.implicits._
 
+        threadUuidC = sessionC.threadUuid.get()
         sessionC.addTag("foo")
         sessionC.removeTag("foo")
        assert(sessionC.getTags() == Set()) // check that remove works removing the last tag
@@ -190,12 +214,13 @@ class SparkSessionJobTaggingAndCancellationSuite
         } finally {
          sessionC.clearTags() // clear for the case of thread reuse by another Future
         }
-      }(executionContext)
+      }
 
       // Block until four jobs have started.
       assert(sem.tryAcquire(3, 1, TimeUnit.MINUTES))
 
       // Tags are applied
+      def realUserTag(s: String, t: String, ta: String): String = s"spark-session-$s-thread-$t-$ta"
       assert(jobProperties.size == 3)
       for (ss <- Seq(sessionA, sessionB, sessionC)) {
        val jobProperty = jobProperties.values().asScala.filter(_.get(SparkContext.SPARK_JOB_TAGS)
@@ -207,15 +232,17 @@ class SparkSessionJobTaggingAndCancellationSuite
         val executionRootIdTag = SQLExecution.executionIdJobTag(
          ss,
          jobProperty.head.get(SQLExecution.EXECUTION_ROOT_ID_KEY).asInstanceOf[String].toLong)
-        val userTagsPrefix = s"spark-session-${ss.sessionUUID}-"
 
         ss match {
           case s if s == sessionA => assert(tags.toSet == Set(
-            s.sessionJobTag, executionRootIdTag, s"${userTagsPrefix}one"))
+            s.sessionJobTag, executionRootIdTag, realUserTag(s.sessionUUID, threadUuidA, "one")))
           case s if s == sessionB => assert(tags.toSet == Set(
-            s.sessionJobTag, executionRootIdTag, s"${userTagsPrefix}one", s"${userTagsPrefix}two"))
+            s.sessionJobTag,
+            executionRootIdTag,
+            realUserTag(s.sessionUUID, threadUuidB, "one"),
+            realUserTag(s.sessionUUID, threadUuidB, "two")))
           case s if s == sessionC => assert(tags.toSet == Set(
-            s.sessionJobTag, executionRootIdTag, s"${userTagsPrefix}boo"))
+            s.sessionJobTag, executionRootIdTag, realUserTag(s.sessionUUID, threadUuidC, "boo")))
         }
       }
 
@@ -239,8 +266,10 @@ class SparkSessionJobTaggingAndCancellationSuite
       assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES))
       assert(jobEnded.intValue == 1)
 
-      // Another job cancelled
-      assert(sessionA.interruptTag("one").size == 1)
+      // Another job cancelled. The next line cancels nothing because we're now in another thread.
+      // The real cancel is done through unblocking a child thread, which is waiting for a lock.
+      assert(sessionA.interruptTag("one").isEmpty)
+      childThreadLock.release()
       val eA = intercept[SparkException] {
         ThreadUtils.awaitResult(jobA, 1.minute)
       }.getCause
@@ -257,7 +286,48 @@ class SparkSessionJobTaggingAndCancellationSuite
       assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES))
       assert(jobEnded.intValue == 3)
     } finally {
-      fpool.shutdownNow()
+      threadPool.shutdownNow()
+    }
+  }
+
+  test("Tags are isolated in multithreaded environment") {
+    // Custom thread pool for multi-threaded testing
+    val threadPool = Executors.newFixedThreadPool(2)
+    implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(threadPool)
+
+    val session = SparkSession.builder().master("local").getOrCreate()
+    @volatile var output1: Set[String] = null
+    @volatile var output2: Set[String] = null
+
+    def tag1(): Unit = {
+      session.addTag("tag1")
+      output1 = session.getTags()
+    }
+
+    def tag2(): Unit = {
+      session.addTag("tag2")
+      output2 = session.getTags()
+    }
+
+    try {
+      // Run tasks in separate threads
+      val future1 = Future {
+        tag1()
+      }
+      val future2 = Future {
+        tag2()
+      }
+
+      // Wait for threads to complete
+      ThreadUtils.awaitResult(Future.sequence(Seq(future1, future2)), 1.minute)
+
+      // Assert outputs
+      assert(output1 != null)
+      assert(output1 == Set("tag1"))
+      assert(output2 != null)
+      assert(output2 == Set("tag2"))
+    } finally {
+      threadPool.shutdownNow()
     }
   }
 }

