This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 37e4c2d1883d [SPARK-55646][SQL] Refactored 
SQLExecution.withThreadLocalCaptured to separate thread-local capture from 
execution
37e4c2d1883d is described below

commit 37e4c2d1883d8ce356cd5dbae555571443c8115a
Author: huanliwang-db <[email protected]>
AuthorDate: Tue Feb 24 12:32:37 2026 +0800

    [SPARK-55646][SQL] Refactored SQLExecution.withThreadLocalCaptured to 
separate thread-local capture from execution
    
    ### What changes were proposed in this pull request?
    
    Previously, callers had to provide an ExecutorService upfront: thread-local 
capture and task submission were fused into a single call that immediately 
returned a CompletableFuture.
    
    Now, captureThreadLocals(sparkSession) captures the current thread's SQL 
context into a standalone SQLExecutionThreadLocalCaptured object. Callers can 
then invoke `runWith { body }` on any thread, at any time, using any 
concurrency primitive — not just ExecutorService.
    
    `withThreadLocalCaptured` is preserved for backward compatibility and now 
delegates to these two primitives.
    
    ### Why are the changes needed?
    
    Refactoring to make withThreadLocalCaptured easier to use.
    
    ### Does this PR introduce _any_ user-facing change?
    
    NO
    
    ### How was this patch tested?
    
    Existing UTs
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #54434 from huanliwang-db/huanliwang-db/refactor-sqlthread.
    
    Authored-by: huanliwang-db <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../apache/spark/sql/execution/SQLExecution.scala  | 76 +++++++++++++++-------
 1 file changed, 51 insertions(+), 25 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index 96a0053f97b1..f25e908a9cdb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -23,7 +23,7 @@ import java.util.concurrent.atomic.AtomicLong
 import scala.jdk.CollectionConverters._
 import scala.util.control.NonFatal
 
-import org.apache.spark.{ErrorMessageFormat, JobArtifactSet, SparkContext, 
SparkEnv, SparkException, SparkThrowable, SparkThrowableHelper}
+import org.apache.spark.{ErrorMessageFormat, JobArtifactSet, JobArtifactState, 
SparkContext, SparkEnv, SparkException, SparkThrowable, SparkThrowableHelper}
 import org.apache.spark.SparkContext.{SPARK_JOB_DESCRIPTION, 
SPARK_JOB_INTERRUPT_ON_CANCEL}
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.{SPARK_DRIVER_PREFIX, 
SPARK_EXECUTOR_PREFIX}
@@ -38,6 +38,43 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.StaticSQLConf.SQL_EVENT_TRUNCATE_LENGTH
 import org.apache.spark.util.{Utils, UUIDv7Generator}
 
+/**
+ * Captures SQL-specific thread-local variables so they can be restored on a 
different thread.
+ * Use [[SQLExecution.captureThreadLocals]] to create an instance on the 
originating thread,
+ * then call [[runWith]] on the target thread to execute a block with these 
thread locals applied.
+ */
+case class SQLExecutionThreadLocalCaptured(
+  sparkSession: SparkSession,
+  localProps: java.util.Properties,
+  artifactState: JobArtifactState) {
+
+  /**
+   * Run the given body with the captured thread-local variables applied on 
the current thread.
+   * Original thread-local values are saved and restored after the body 
completes.
+   */
+  def runWith[T](body: => T): T = {
+    val sc = sparkSession.sparkContext
+    JobArtifactSet.withActiveJobArtifactState(artifactState) {
+      val originalSession = SparkSession.getActiveSession
+      val originalLocalProps = sc.getLocalProperties
+      SparkSession.setActiveSession(sparkSession)
+      val res = SQLExecution.withSessionTagsApplied(sparkSession) {
+        sc.setLocalProperties(localProps)
+        val res = body
+        // reset active session and local props.
+        sc.setLocalProperties(originalLocalProps)
+        res
+      }
+      if (originalSession.nonEmpty) {
+        SparkSession.setActiveSession(originalSession.get)
+      } else {
+        SparkSession.clearActiveSession()
+      }
+      res
+    }
+  }
+}
+
 object SQLExecution extends Logging {
 
   val EXECUTION_ID_KEY = "spark.sql.execution.id"
@@ -343,36 +380,25 @@ object SQLExecution extends Logging {
     }
   }
 
+  def captureThreadLocals(sparkSession: SparkSession): 
SQLExecutionThreadLocalCaptured = {
+    val sc = sparkSession.sparkContext
+    val localProps = Utils.cloneProperties(sc.getLocalProperties)
+    // `getCurrentJobArtifactState` will return a state only in Spark Connect 
mode. In non-Connect
+    // mode, we default back to the resources of the current Spark session.
+    val artifactState =
+      
JobArtifactSet.getCurrentJobArtifactState.getOrElse(sparkSession.artifactManager.state)
+    SQLExecutionThreadLocalCaptured(sparkSession, localProps, artifactState)
+  }
+
   /**
    * Wrap passed function to ensure necessary thread-local variables like
    * SparkContext local properties are forwarded to execution thread
    */
   def withThreadLocalCaptured[T](
       sparkSession: SparkSession, exec: ExecutorService) (body: => T): 
CompletableFuture[T] = {
-    val activeSession = sparkSession
-    val sc = sparkSession.sparkContext
-    val localProps = Utils.cloneProperties(sc.getLocalProperties)
-    // `getCurrentJobArtifactState` will return a stat only in Spark Connect 
mode. In non-Connect
-    // mode, we default back to the resources of the current Spark session.
-    val artifactState = JobArtifactSet.getCurrentJobArtifactState.getOrElse(
-      activeSession.artifactManager.state)
-    CompletableFuture.supplyAsync(() => 
JobArtifactSet.withActiveJobArtifactState(artifactState) {
-      val originalSession = SparkSession.getActiveSession
-      val originalLocalProps = sc.getLocalProperties
-      SparkSession.setActiveSession(activeSession)
-      val res = withSessionTagsApplied(activeSession) {
-        sc.setLocalProperties(localProps)
-        val res = body
-        // reset active session and local props.
-        sc.setLocalProperties(originalLocalProps)
-        res
-      }
-      if (originalSession.nonEmpty) {
-        SparkSession.setActiveSession(originalSession.get)
-      } else {
-        SparkSession.clearActiveSession()
-      }
-      res
+    val threadLocalCaptured = captureThreadLocals(sparkSession)
+    CompletableFuture.supplyAsync(() => {
+      threadLocalCaptured.runWith(body)
     }, exec)
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to