This is an automated email from the ASF dual-hosted git repository.
gengliangwang pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.x by this push:
new e3ac0d3d6819 [SPARK-57196][SQL] Make UnionExec whole-stage codegen thread-safe
e3ac0d3d6819 is described below
commit e3ac0d3d6819f828fc0d88d01b134e9932eb06ca
Author: Gengliang Wang <[email protected]>
AuthorDate: Wed Jun 3 08:14:54 2026 -0700
[SPARK-57196][SQL] Make UnionExec whole-stage codegen thread-safe
### What changes were proposed in this pull request?
`UnionExec` whole-stage codegen fusion (SPARK-56482) kept per-emission
codegen state in mutable instance fields on the plan node:
`currentEmittingChild` (set in `doProduce`, read in `doConsume` to pick a
child's projection) and `numOutputRowsTerm` (the once-per-stage metric term).
This PR moves both fields to `ThreadLocal`, isolating the state to the single
thread that runs a given `doCodeGen` pass.
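Here is a minimal, Spark-independent sketch of moving a field to a `ThreadLocal`. It assumes a hypothetical `FusedUnionSketch` class that only mirrors the `doProduce`/`doConsume` shape. It is not Spark's `UnionExec`, and `produce`, `consume`, and `raceCount` are illustrative names. Each thread that calls `produce()` sets and reads its own copy of `currentEmittingChild`. One shared instance can therefore be driven from many threads without one thread's reset leaking into another thread's emission window:

```scala
import java.util.concurrent.{CountDownLatch, Executors, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger

// Hypothetical stand-in for a plan node with per-emission codegen state.
// It is NOT Spark's UnionExec; it only mirrors the produce/consume shape.
final class FusedUnionSketch(numChildren: Int) {
  // Was (racy): `private var currentEmittingChild: Int = -1`, shared by all threads.
  private val currentEmittingChild: ThreadLocal[Int] =
    ThreadLocal.withInitial(() => -1)

  def produce(): String = {
    val cases = (0 until numChildren).map { i =>
      currentEmittingChild.set(i) // visible only to the calling thread
      consume()
    }
    currentEmittingChild.set(-1) // close this thread's emission window
    cases.mkString(";")
  }

  private def consume(): String = {
    val i = currentEmittingChild.get
    require(i >= 0, "consume invoked outside produce emission window")
    s"emit(child$i)"
  }
}

object FusedUnionSketch {
  // Drives produce() on one shared instance from several threads at once and
  // counts require() failures plus outputs that differ from a single-threaded run.
  def raceCount(threads: Int, iterations: Int): Int = {
    val shared = new FusedUnionSketch(numChildren = 3)
    val expected = shared.produce()
    val failures = new AtomicInteger(0)
    val start = new CountDownLatch(1) // releases all workers together
    val pool = Executors.newFixedThreadPool(threads)
    try {
      val futures = (0 until threads).map { _ =>
        pool.submit(new Runnable {
          override def run(): Unit = {
            start.await()
            for (_ <- 0 until iterations) {
              try {
                if (shared.produce() != expected) failures.incrementAndGet()
              } catch {
                case _: IllegalArgumentException => failures.incrementAndGet()
              }
            }
          }
        })
      }
      start.countDown()
      futures.foreach(_.get(60, TimeUnit.SECONDS))
    } finally {
      pool.shutdownNow()
    }
    failures.get
  }
}
```

For example, `new FusedUnionSketch(2).produce()` yields `emit(child0);emit(child1)`, and `FusedUnionSketch.raceCount(8, 200)` should report `0`. If the state is declared as a plain `var` instead, the same harness may fail intermittently on the `require`. Whether it actually fails depends on thread timing.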
### Why are the changes needed?
A single `UnionExec` instance can have its whole-stage codegen driven by
more than one thread at the same time: a reused exchange/subquery stage is
generated concurrently with the main plan, and async subquery /
dynamic-partition-pruning execution can overlap a driver-side `doCodeGen`. With
the shared mutable field, a racing `doProduce` resets `currentEmittingChild` to
`-1` while another thread is still inside `doConsume`, tripping:
```
java.lang.IllegalArgumentException: requirement failed:
UnionExec.doConsume invoked outside doProduce emission window
```
This surfaced as a flaky `LogicalPlanTagInSparkPlanSuite.q2` failure (q2
contains a `UNION`, and union fusion is enabled by default). Each `doCodeGen`
pass is itself single-threaded (`produce` -> `doConsume` run inline on one
thread), so a `ThreadLocal` isolates the state per pass without the
cross-thread race, while preserving the existing per-stage semantics (the
metric term is still computed once per pass).
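A `ThreadLocal` is scoped to a thread, not to a `doCodeGen` pass. The per-pass isolation described above therefore holds only while each pass runs inline on one thread and clears its state when it finishes. The following sketch uses a hypothetical `PooledThreadReset` object and `emittingChild` marker (neither is Spark code). It runs two back-to-back "passes" on the same single pooled thread and shows what the second pass observes with and without that reset:

```scala
import java.util.concurrent.{Callable, Executors, TimeUnit}

object PooledThreadReset {
  // Hypothetical per-pass marker; -1 means "not inside an emission window".
  private val emittingChild: ThreadLocal[Int] = ThreadLocal.withInitial(() => -1)

  // Runs two "passes" one after the other on a single pooled thread and returns
  // the value the second pass reads before setting anything itself.
  def observedBySecondPass(resetAtEnd: Boolean): Int = {
    val pool = Executors.newSingleThreadExecutor() // one worker, reused by both tasks
    try {
      val firstPass: Callable[Unit] = () => {
        emittingChild.set(1)                  // the pass opens its window...
        if (resetAtEnd) emittingChild.set(-1) // ...and optionally closes it
      }
      val secondPass: Callable[Int] = () => emittingChild.get
      pool.submit(firstPass).get(10, TimeUnit.SECONDS)
      pool.submit(secondPass).get(10, TimeUnit.SECONDS)
    } finally {
      pool.shutdownNow()
    }
  }
}
```

`observedBySecondPass(resetAtEnd = false)` returns `1`, a stale value left over from the first pass. `observedBySecondPass(resetAtEnd = true)` returns `-1`. This mirrors the patch's note that resetting `currentEmittingChild` to `-1` at the end of `doProduce` guards against a later pass on the same pooled thread reading a stale value.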
### Does this PR introduce _any_ user-facing change?
No. It removes an intermittent internal code-generation failure; the
generated code and query results are unchanged.
### How was this patch tested?
Added a `UnionCodegenSuite` test, "SPARK-57196: concurrent codegen of a
shared UnionExec stage is thread-safe", that drives `doCodeGen()` on one shared
fused `UnionExec` stage from 8 threads. It reproduces the "outside doProduce
emission window" failure on the unpatched code and passes with this fix. Also
verified the full `UnionCodegenSuite` (43 tests), its ANSI/AQE variants, and
`LogicalPlanTagInSparkPlanSuite` q2 all pass.
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Claude Code (Opus 4.8)
Closes #56252 from gengliangwang/spark-union-codegen-threadsafe.
Authored-by: Gengliang Wang <[email protected]>
Signed-off-by: Gengliang Wang <[email protected]>
(cherry picked from commit 694c848a86ad12000c8b1f3f701f6c3c4fc8535a)
Signed-off-by: Gengliang Wang <[email protected]>
---
.../sql/execution/basicPhysicalOperators.scala | 46 ++++++++++++++------
.../spark/sql/execution/UnionCodegenSuite.scala | 50 ++++++++++++++++++++++
2 files changed, 84 insertions(+), 12 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 9ed9c312a4be..81c048022c79 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -1017,16 +1017,38 @@ case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan with CodegenSup
override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(unionedInputRDD)
- // Set in `doProduce`, read in `doConsume` during single-threaded code
- // emission. `numOutputRowsTerm` is registered once per stage so the
- // metric appears in `references[]` exactly once instead of once per
- // child. `currentEmittingChild` tells `doConsume` which child's
- // projection to bind.
- @transient private var numOutputRowsTerm: String = _
- @transient private var currentEmittingChild: Int = -1
+ // Per-emission codegen state, set in `doProduce` and read in `doConsume`.
+ // `numOutputRowsTerm` is registered once per stage so the metric appears in
+ // `references[]` exactly once instead of once per child; `currentEmittingChild`
+ // tells `doConsume` which child's projection to bind.
+ //
+ // A single `UnionExec` instance can have its codegen driven by more than one
+ // thread at the same time: a reused exchange/subquery stage is generated
+ // concurrently with the main plan, and async subquery / dynamic-pruning
+ // execution can overlap a driver-side `doCodeGen`. A plain field would let a
+ // racing `doProduce` reset `currentEmittingChild` to -1 while another thread
+ // is still in `doConsume`. Each `doCodeGen` pass is itself single-threaded
+ // (`produce` -> `doConsume` run inline on one thread), so a `ThreadLocal`
+ // isolates the state per pass without that cross-thread race.
+ //
+ // This state is valid only for the duration of one `doCodeGen` pass, not for
+ // the lifetime of a thread (much like the per-pass fields on `CodegenContext`,
+ // e.g. `currentPartitionIndexVar`, which `doProduce` saves and restores just
+ // below). `ThreadLocal` is correct because per-pass and per-thread coincide
+ // here: a pass runs inline on one thread and passes never nest on a thread.
+ // We keep it in a `ThreadLocal` rather than routing it through `ctx` because
+ // `CodegenContext` has no general-purpose per-pass attribute map; threading it
+ // through `ctx` would mean adding `UnionExec`-specific fields to a class shared
+ // by every operator. The `ThreadLocal` keeps this state local to the node that
+ // needs it. Resetting `currentEmittingChild` to -1 at the end of `doProduce`
+ // also guards against a stale value being read by a later, unrelated pass
+ // that reuses the same pooled thread.
+ @transient private lazy val numOutputRowsTerm = new ThreadLocal[String]
+ @transient private lazy val currentEmittingChild: ThreadLocal[Int] =
+ ThreadLocal.withInitial(() => -1)
override protected def doProduce(ctx: CodegenContext): String = {
- numOutputRowsTerm = metricTerm(ctx, "numOutputRows")
+ numOutputRowsTerm.set(metricTerm(ctx, "numOutputRows"))
// For each partition of the unioned RDD, record its owning child and its
// index within that child's RDD. Read both fields directly off the
@@ -1061,7 +1083,7 @@ case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan with CodegenSup
val savedPartIdxVar = ctx.currentPartitionIndexVar
ctx.currentPartitionIndexVar = s"((int[]) $p2lRef)[partitionIndex]"
val cases = children.zipWithIndex.map { case (c, i) =>
- currentEmittingChild = i
+ currentEmittingChild.set(i)
val producedCode = c.asInstanceOf[CodegenSupport].produce(ctx, this)
val helper = ctx.freshName("unionChildProcess")
val qualifiedHelper = ctx.addNewFunction(helper,
@@ -1075,7 +1097,7 @@ case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan with CodegenSup
| break;
|}""".stripMargin
}
- currentEmittingChild = -1
+ currentEmittingChild.set(-1)
ctx.currentPartitionIndexVar = savedPartIdxVar
s"""
@@ -1091,7 +1113,7 @@ case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan with CodegenSup
override def doConsume(
ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
- val i = currentEmittingChild
+ val i = currentEmittingChild.get
require(i >= 0, "UnionExec.doConsume invoked outside doProduce emission window")
// Route BoundReference reads through `currentVars` (the incoming row is
// delivered as variables under WSCG, not via ctx.INPUT_ROW).
@@ -1101,7 +1123,7 @@ case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan with CodegenSup
val projectedExprCodes = bound.map(_.genCode(ctx))
s"""
- |$numOutputRowsTerm.add(1L);
+ |${numOutputRowsTerm.get}.add(1L);
|${consume(ctx, projectedExprCodes)}
""".stripMargin
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnionCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnionCodegenSuite.scala
index 8cc35ed43a10..6af286b610e2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnionCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnionCodegenSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution
+import java.util.concurrent.{CountDownLatch, Executors, TimeUnit}
+
import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions._
@@ -653,6 +655,54 @@ class UnionCodegenSuite extends SharedSparkSession {
df.collect()
assertFlagParity(() => a.union(b).orderBy("id"))
}
+
+ test("SPARK-57196: concurrent codegen of a shared UnionExec stage is thread-safe") {
+ // A single `UnionExec` instance can have its whole-stage codegen driven by
+ // more than one thread at a time: a reused exchange/subquery is generated
+ // concurrently with the main plan, and async subquery/DPP execution can
+ // overlap a driver-side `doCodeGen`. The fusion path kept per-emission state
+ // (`currentEmittingChild`) in a mutable field on the shared instance, so a
+ // racing `doProduce` could reset it to -1 while another thread was still in
+ // `doConsume`, tripping the "UnionExec.doConsume invoked outside doProduce
+ // emission window" requirement. Generating the same fused stage from many
+ // threads reproduces the race.
+ val df = rangeDF(100).union(rangeDF(100)).filter(col("id") > 0)
+ assert(unionInsideWSCG(df))
+ val wscg = df.queryExecution.executedPlan.collectFirst {
+ case w: WholeStageCodegenExec if w.find(_.isInstanceOf[UnionExec]).isDefined => w
+ }.getOrElse(fail("expected a fused UnionExec stage"))
+
+ val numThreads = 8
+ val iterations = 200
+ val pool = Executors.newFixedThreadPool(numThreads)
+ val errors = java.util.Collections.synchronizedList(new java.util.ArrayList[Throwable]())
+ try {
+ val startLatch = new CountDownLatch(1)
+ val futures = (0 until numThreads).map { _ =>
+ pool.submit(new Runnable {
+ override def run(): Unit = {
+ startLatch.await()
+ var n = 0
+ while (n < iterations) {
+ try {
+ wscg.doCodeGen()
+ } catch {
+ case t: Throwable => errors.add(t)
+ }
+ n += 1
+ }
+ }
+ })
+ }
+ startLatch.countDown()
+ futures.foreach(_.get(60, TimeUnit.SECONDS))
+ } finally {
+ pool.shutdownNow()
+ }
+ assert(errors.isEmpty,
+ "concurrent doCodeGen on a shared UnionExec stage raced:\n" +
+ errors.toArray.map(_.toString).mkString("\n"))
+ }
}
/** Runs [[UnionCodegenSuite]] with ANSI mode enabled. */