This is an automated email from the ASF dual-hosted git repository.

gengliangwang pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.x by this push:
     new e3ac0d3d6819 [SPARK-57196][SQL] Make UnionExec whole-stage codegen thread-safe
e3ac0d3d6819 is described below

commit e3ac0d3d6819f828fc0d88d01b134e9932eb06ca
Author: Gengliang Wang <[email protected]>
AuthorDate: Wed Jun 3 08:14:54 2026 -0700

    [SPARK-57196][SQL] Make UnionExec whole-stage codegen thread-safe
    
    ### What changes were proposed in this pull request?
    
    `UnionExec` whole-stage codegen fusion (SPARK-56482) kept per-emission
    codegen state in mutable instance fields on the plan node:
    `currentEmittingChild` (set in `doProduce`, read in `doConsume` to pick a
    child's projection) and `numOutputRowsTerm` (the once-per-stage metric term).
    This PR moves both fields to `ThreadLocal`, isolating the state to the single
    thread that runs a given `doCodeGen` pass.
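    
    A minimal sketch of the pattern, using a hypothetical `ToyUnionEmitter`
    rather than Spark's actual `UnionExec`/`CodegenContext` API: `produce` plays
    the role of `doProduce` and `consume` the role of `doConsume`. Both pieces of
    per-pass state live in `ThreadLocal`s, so concurrent passes over one shared
    instance each see only their own values, and the child index is reset to -1
    once emission ends.
    
    ```scala
    // Hypothetical stand-in for UnionExec's produce/consume protocol; not Spark code.
    // Per-pass state lives in ThreadLocals instead of plain vars, so concurrent
    // passes over the same shared instance cannot overwrite each other's values.
    final class ToyUnionEmitter(numChildren: Int) {
      private val metricTerm = new ThreadLocal[String]
      private val currentChild: ThreadLocal[Int] = ThreadLocal.withInitial(() => -1)

      // Analogue of doProduce: register the metric term once, emit each child,
      // then reset the child index so no stale value survives on this thread.
      def produce(passId: String): Seq[String] = {
        metricTerm.set(s"numOutputRows_$passId")
        val emitted = (0 until numChildren).map { i =>
          currentChild.set(i)
          consume()
        }
        currentChild.set(-1)
        emitted
      }

      // Analogue of doConsume: only valid inside produce's emission window.
      def consume(): String = {
        val i = currentChild.get
        require(i >= 0, "consume invoked outside produce emission window")
        s"${metricTerm.get}.add(1L) // child $i"
      }
    }
    ```
    
    Calling `consume()` on a thread that is not inside `produce` fails the
    `require`, mirroring the emission-window check in `UnionExec.doConsume`.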
    
    ### Why are the changes needed?
    
    A single `UnionExec` instance can have its whole-stage codegen driven by
    more than one thread at the same time: a reused exchange/subquery stage is
    generated concurrently with the main plan, and async subquery /
    dynamic-partition-pruning execution can overlap a driver-side `doCodeGen`.
    With the shared mutable field, a racing `doProduce` resets
    `currentEmittingChild` to `-1` while another thread is still inside
    `doConsume`, tripping:
    
    ```
    java.lang.IllegalArgumentException: requirement failed:
      UnionExec.doConsume invoked outside doProduce emission window
    ```
    
    This surfaced as a flaky `LogicalPlanTagInSparkPlanSuite.q2` failure (q2
    contains a `UNION`, and union fusion is enabled by default). Each `doCodeGen`
    pass is itself single-threaded (`produce` -> `doConsume` run inline on one
    thread), so a `ThreadLocal` isolates the state per pass without the
    cross-thread race, while preserving the existing per-stage semantics (the
    metric term is still computed once per pass).
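    
    The race can be reproduced deterministically outside Spark. In this
    hypothetical `RaceDemo`, two `CountDownLatch`es force the problematic
    interleaving: thread A writes its child index, thread B then resets it to -1
    (as a racing `doProduce` would), and only afterwards does A read it back. A
    reads B's -1 from the shared field but its own 0 from the `ThreadLocal`.
    
    ```scala
    import java.util.concurrent.CountDownLatch

    // Hypothetical demo, not Spark code: compares a plain shared field with a
    // ThreadLocal under a forced "write, foreign reset, read back" interleaving.
    object RaceDemo {
      @volatile private var sharedChild: Int = -1
      private val localChild: ThreadLocal[Int] = ThreadLocal.withInitial(() => -1)

      /** Returns what thread A reads back after thread B's reset:
       *  (value from the shared field, value from the ThreadLocal). */
      def run(): (Int, Int) = {
        val aWrote = new CountDownLatch(1)
        val bReset = new CountDownLatch(1)
        var seenShared = Int.MinValue
        var seenLocal = Int.MinValue
        val a = new Thread(new Runnable {
          override def run(): Unit = {
            sharedChild = 0      // A enters its emission window for child 0
            localChild.set(0)
            aWrote.countDown()
            bReset.await()       // wait until B has "reset" the state
            seenShared = sharedChild
            seenLocal = localChild.get
          }
        })
        val b = new Thread(new Runnable {
          override def run(): Unit = {
            aWrote.await()
            sharedChild = -1     // B's reset clobbers A's shared value...
            localChild.set(-1)   // ...but only touches B's own ThreadLocal slot
            bReset.countDown()
          }
        })
        a.start(); b.start()
        a.join(); b.join()
        (seenShared, seenLocal)
      }
    }
    ```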
    
    ### Does this PR introduce _any_ user-facing change?
    
    No. It removes an intermittent internal code-generation failure; the
    generated code and query results are unchanged.
    
    ### How was this patch tested?
    
    Added a `UnionCodegenSuite` test, "SPARK-57196: concurrent codegen of a
    shared UnionExec stage is thread-safe", that drives `doCodeGen()` on one
    shared fused `UnionExec` stage from 8 threads. It reproduces the "outside
    doProduce emission window" failure on the unpatched code and passes with
    this fix. Also verified that the full `UnionCodegenSuite` (43 tests), its
    ANSI/AQE variants, and `LogicalPlanTagInSparkPlanSuite` q2 all pass.
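    
    The new test follows a common stress-test shape: a fixed thread pool, a
    `CountDownLatch` that releases all workers at once to maximize overlap, and
    a synchronized list that collects failures instead of aborting on the first
    one. A generic version, with a hypothetical `ConcurrentHarness.run` helper
    that is not part of Spark's test utilities, might look like this:
    
    ```scala
    import java.util.concurrent.{CountDownLatch, Executors, TimeUnit}
    import java.util.{ArrayList, Collections}

    // Hypothetical helper mirroring the structure of the SPARK-57196 test.
    object ConcurrentHarness {
      /** Runs `body` `iterations` times on each of `numThreads` threads, all
       *  released together, and returns every Throwable the body raised. */
      def run(numThreads: Int, iterations: Int)(body: () => Unit): Seq[Throwable] = {
        val pool = Executors.newFixedThreadPool(numThreads)
        val errors = Collections.synchronizedList(new ArrayList[Throwable]())
        try {
          val start = new CountDownLatch(1)
          val futures = (0 until numThreads).map { _ =>
            pool.submit(new Runnable {
              override def run(): Unit = {
                start.await() // block until every worker has been submitted
                var n = 0
                while (n < iterations) {
                  // Record the failure and keep going, so one run reports them all.
                  try body() catch { case t: Throwable => errors.add(t) }
                  n += 1
                }
              }
            })
          }
          start.countDown()
          futures.foreach(_.get(60, TimeUnit.SECONDS))
        } finally {
          pool.shutdownNow()
        }
        (0 until errors.size).map(i => errors.get(i))
      }
    }
    ```
    
    In the suite, the body is `wscg.doCodeGen()` on the shared fused stage, and
    an empty result means no iteration threw.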
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Generated-by: Claude Code (Opus 4.8)
    
    Closes #56252 from gengliangwang/spark-union-codegen-threadsafe.
    
    Authored-by: Gengliang Wang <[email protected]>
    Signed-off-by: Gengliang Wang <[email protected]>
    (cherry picked from commit 694c848a86ad12000c8b1f3f701f6c3c4fc8535a)
    Signed-off-by: Gengliang Wang <[email protected]>
---
 .../sql/execution/basicPhysicalOperators.scala     | 46 ++++++++++++++------
 .../spark/sql/execution/UnionCodegenSuite.scala    | 50 ++++++++++++++++++++++
 2 files changed, 84 insertions(+), 12 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 9ed9c312a4be..81c048022c79 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -1017,16 +1017,38 @@ case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan with CodegenSup
 
   override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(unionedInputRDD)
 
-  // Set in `doProduce`, read in `doConsume` during single-threaded code
-  // emission. `numOutputRowsTerm` is registered once per stage so the
-  // metric appears in `references[]` exactly once instead of once per
-  // child. `currentEmittingChild` tells `doConsume` which child's
-  // projection to bind.
-  @transient private var numOutputRowsTerm: String = _
-  @transient private var currentEmittingChild: Int = -1
+  // Per-emission codegen state, set in `doProduce` and read in `doConsume`.
+  // `numOutputRowsTerm` is registered once per stage so the metric appears in
+  // `references[]` exactly once instead of once per child; `currentEmittingChild`
+  // tells `doConsume` which child's projection to bind.
+  //
+  // A single `UnionExec` instance can have its codegen driven by more than one
+  // thread at the same time: a reused exchange/subquery stage is generated
+  // concurrently with the main plan, and async subquery / dynamic-pruning
+  // execution can overlap a driver-side `doCodeGen`. A plain field would let a
+  // racing `doProduce` reset `currentEmittingChild` to -1 while another thread
+  // is still in `doConsume`. Each `doCodeGen` pass is itself single-threaded
+  // (`produce` -> `doConsume` run inline on one thread), so a `ThreadLocal`
+  // isolates the state per pass without that cross-thread race.
+  //
+  // This state is valid only for the duration of one `doCodeGen` pass, not for
+  // the lifetime of a thread (much like the per-pass fields on `CodegenContext`,
+  // e.g. `currentPartitionIndexVar`, which `doProduce` saves and restores just
+  // below). `ThreadLocal` is correct because per-pass and per-thread coincide
+  // here: a pass runs inline on one thread and passes never nest on a thread.
+  // We keep it in a `ThreadLocal` rather than routing it through `ctx` because
+  // `CodegenContext` has no general-purpose per-pass attribute map; threading it
+  // through `ctx` would mean adding `UnionExec`-specific fields to a class shared
+  // by every operator. The `ThreadLocal` keeps this state local to the node that
+  // needs it. Resetting `currentEmittingChild` to -1 at the end of `doProduce`
+  // also guards against a stale value being read by a later, unrelated pass
+  // that reuses the same pooled thread.
+  @transient private lazy val numOutputRowsTerm = new ThreadLocal[String]
+  @transient private lazy val currentEmittingChild: ThreadLocal[Int] =
+    ThreadLocal.withInitial(() => -1)
 
   override protected def doProduce(ctx: CodegenContext): String = {
-    numOutputRowsTerm = metricTerm(ctx, "numOutputRows")
+    numOutputRowsTerm.set(metricTerm(ctx, "numOutputRows"))
 
     // For each partition of the unioned RDD, record its owning child and its
     // index within that child's RDD. Read both fields directly off the
@@ -1061,7 +1083,7 @@ case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan with CodegenSup
     val savedPartIdxVar = ctx.currentPartitionIndexVar
     ctx.currentPartitionIndexVar = s"((int[]) $p2lRef)[partitionIndex]"
     val cases = children.zipWithIndex.map { case (c, i) =>
-      currentEmittingChild = i
+      currentEmittingChild.set(i)
       val producedCode = c.asInstanceOf[CodegenSupport].produce(ctx, this)
       val helper = ctx.freshName("unionChildProcess")
       val qualifiedHelper = ctx.addNewFunction(helper,
@@ -1075,7 +1097,7 @@ case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan with CodegenSup
          |  break;
          |}""".stripMargin
     }
-    currentEmittingChild = -1
+    currentEmittingChild.set(-1)
     ctx.currentPartitionIndexVar = savedPartIdxVar
 
     s"""
@@ -1091,7 +1113,7 @@ case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan with CodegenSup
 
   override def doConsume(
       ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
-    val i = currentEmittingChild
+    val i = currentEmittingChild.get
     require(i >= 0, "UnionExec.doConsume invoked outside doProduce emission window")
     // Route BoundReference reads through `currentVars` (the incoming row is
     // delivered as variables under WSCG, not via ctx.INPUT_ROW).
@@ -1101,7 +1123,7 @@ case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan with CodegenSup
     val projectedExprCodes = bound.map(_.genCode(ctx))
 
     s"""
-       |$numOutputRowsTerm.add(1L);
+       |${numOutputRowsTerm.get}.add(1L);
        |${consume(ctx, projectedExprCodes)}
      """.stripMargin
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnionCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnionCodegenSuite.scala
index 8cc35ed43a10..6af286b610e2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnionCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnionCodegenSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution
 
+import java.util.concurrent.{CountDownLatch, Executors, TimeUnit}
+
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.functions._
@@ -653,6 +655,54 @@ class UnionCodegenSuite extends SharedSparkSession {
     df.collect()
     assertFlagParity(() => a.union(b).orderBy("id"))
   }
+
+  test("SPARK-57196: concurrent codegen of a shared UnionExec stage is thread-safe") {
+    // A single `UnionExec` instance can have its whole-stage codegen driven by
+    // more than one thread at a time: a reused exchange/subquery is generated
+    // concurrently with the main plan, and async subquery/DPP execution can
+    // overlap a driver-side `doCodeGen`. The fusion path kept per-emission state
+    // (`currentEmittingChild`) in a mutable field on the shared instance, so a
+    // racing `doProduce` could reset it to -1 while another thread was still in
+    // `doConsume`, tripping the "UnionExec.doConsume invoked outside doProduce
+    // emission window" requirement. Generating the same fused stage from many
+    // threads reproduces the race.
+    val df = rangeDF(100).union(rangeDF(100)).filter(col("id") > 0)
+    assert(unionInsideWSCG(df))
+    val wscg = df.queryExecution.executedPlan.collectFirst {
+      case w: WholeStageCodegenExec if w.find(_.isInstanceOf[UnionExec]).isDefined => w
+    }.getOrElse(fail("expected a fused UnionExec stage"))
+
+    val numThreads = 8
+    val iterations = 200
+    val pool = Executors.newFixedThreadPool(numThreads)
+    val errors = java.util.Collections.synchronizedList(new java.util.ArrayList[Throwable]())
+    try {
+      val startLatch = new CountDownLatch(1)
+      val futures = (0 until numThreads).map { _ =>
+        pool.submit(new Runnable {
+          override def run(): Unit = {
+            startLatch.await()
+            var n = 0
+            while (n < iterations) {
+              try {
+                wscg.doCodeGen()
+              } catch {
+                case t: Throwable => errors.add(t)
+              }
+              n += 1
+            }
+          }
+        })
+      }
+      startLatch.countDown()
+      futures.foreach(_.get(60, TimeUnit.SECONDS))
+    } finally {
+      pool.shutdownNow()
+    }
+    assert(errors.isEmpty,
+      "concurrent doCodeGen on a shared UnionExec stage raced:\n" +
+        errors.toArray.map(_.toString).mkString("\n"))
+  }
 }
 
 /** Runs [[UnionCodegenSuite]] with ANSI mode enabled. */

