This is an automated email from the ASF dual-hosted git repository.
LuciferYang pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.x by this push:
new 1478984e707b [SPARK-56873][CORE] Fix potential race condition in bounded k-way merge in UnsafeExternalSorter
1478984e707b is described below
commit 1478984e707bda128ecf3f79c7020c0ccfad6bd2
Author: Tengfei Huang <[email protected]>
AuthorDate: Wed May 20 00:01:18 2026 +0800
[SPARK-56873][CORE] Fix potential race condition in bounded k-way merge in UnsafeExternalSorter
### What changes were proposed in this pull request?
Enhancement of the bounded multi-round k-way merge in
`UnsafeExternalSorter` added in
[SPARK-56410](https://issues.apache.org/jira/browse/SPARK-56410). Avoid the
potential races between `getSortedIterator` and `spill`.
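For readers new to the bounded merge, here is a minimal sketch of the idea. It assumes sorted `int` lists stand in for spill files. Intermediate rounds merge at most `mergeFactor` runs at a time into a fresh list, and an optional in-memory run joins only the final round. `BoundedMergeSketch`, `kWayMerge` and `boundedMerge` are hypothetical names, not Spark APIs. The real `UnsafeSorterBoundedSpillMerger` writes intermediate rounds to disk rather than keeping them in memory.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;

/** Hypothetical, in-memory model of a bounded multi-round k-way merge. */
final class BoundedMergeSketch {

  /** Classic k-way merge of sorted runs using a min-heap of {value, runIndex, position}. */
  static List<Integer> kWayMerge(List<List<Integer>> runs) {
    PriorityQueue<int[]> heap = new PriorityQueue<>((a, b) -> Integer.compare(a[0], b[0]));
    for (int r = 0; r < runs.size(); r++) {
      if (!runs.get(r).isEmpty()) {
        heap.add(new int[] {runs.get(r).get(0), r, 0});
      }
    }
    List<Integer> out = new ArrayList<>();
    while (!heap.isEmpty()) {
      int[] top = heap.poll();
      out.add(top[0]);
      List<Integer> run = runs.get(top[1]);
      int next = top[2] + 1;
      if (next < run.size()) {
        heap.add(new int[] {run.get(next), top[1], next});
      }
    }
    return out;
  }

  /**
   * Merges at most mergeFactor runs per round. Each intermediate round builds a
   * fresh list and reassigns the local variable, so the caller's list is never
   * mutated. The optional in-memory run participates only in the final round.
   */
  static List<Integer> boundedMerge(
      List<List<Integer>> spills, List<Integer> inMem, int mergeFactor) {
    if (mergeFactor < 2) {
      // In this sketch a factor below 2 would never shrink the run count.
      throw new IllegalArgumentException("mergeFactor must be >= 2");
    }
    List<List<Integer>> toMerge = spills;  // read-only use, so no copy is needed here
    while (toMerge.size() > mergeFactor) {
      List<List<Integer>> nextRound = new ArrayList<>();
      for (int i = 0; i < toMerge.size(); i += mergeFactor) {
        int end = Math.min(i + mergeFactor, toMerge.size());
        nextRound.add(kWayMerge(toMerge.subList(i, end)));  // stands in for an intermediate spill
      }
      toMerge = nextRound;
    }
    List<List<Integer>> finalRound = new ArrayList<>(toMerge);
    if (inMem != null) {
      finalRound.add(inMem);
    }
    return kWayMerge(finalRound);
  }
}
```

With four spills and `mergeFactor = 2`, the sketch first merges pairs into two runs and then merges those two runs together with the in-memory run.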
Currently the race is not reachable through any first-party OSS execution
path. It IS reachable for users running third-party `SparkPlugin`s that
register `MemoryConsumer`s driven from non-task threads. With the bounded merge
feature enabled, those plugins can potentially hit the race.
#### The race condition:
Before this PR, the bounded branch of `getSortedIterator()` did roughly:
```java
boundedMerger = new UnsafeSorterBoundedSpillMerger(...);   // step 1
readingIterator = new SpillableIterator(...);              // step 2: volatile publish
return boundedMerger.merge(spillWriters, readingIterator); // step 3: snapshot taken INSIDE merge()
```
`UnsafeSorterBoundedSpillMerger.merge()` snapshots its inputs via `new
ArrayList<>(spillWriters)` at the top — so the snapshot is taken in step 3,
**after** `readingIterator` is published in step 2.
Once `readingIterator` is non-null, a sibling consumer's
`acquireExecutionMemory()` call that picks this sorter as the spill victim is routed
to `readingIterator.spill()`, which (a) appends a new writer to the live
`spillWriters` list, and (b) rebinds `readingIterator.upstream` to read that
same new file. If this happens between step 2 and step 3, the snapshot taken
inside `merge()` includes the new writer, **and** the final merge round also
pulls from `readingIterator` (whose upstream now [...]
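The interleaving can be replayed deterministically in a small model. This sketch makes a few assumptions. Integer lists stand in for spill files. The "concurrent" spill runs on the same thread, at the exact point where it would race. `SpillRaceModel` and its methods are hypothetical and only mimic the (a)/(b) behavior of `readingIterator.spill()` described above.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;

/** Hypothetical single-threaded replay of the publish/spill/snapshot interleaving. */
final class SpillRaceModel {
  /** Live list that the spill path appends to (stands in for spillWriters). */
  final List<List<Integer>> spillWriters = new ArrayList<>();

  /** Stand-in for SpillableIterator. */
  final class ReadingIterator {
    Iterator<Integer> upstream;

    ReadingIterator(List<Integer> inMemory) {
      this.upstream = inMemory.iterator();
    }

    void spill() {
      List<Integer> file = new ArrayList<>();
      upstream.forEachRemaining(file::add); // drain unread records into a new "file"
      spillWriters.add(file);               // (a) append to the LIVE list
      upstream = file.iterator();           // (b) rebind upstream to that same file
    }
  }

  /** Reads every snapshotted file, then drains the reading iterator; sorted for comparison. */
  static List<Integer> merge(List<List<Integer>> snapshot, ReadingIterator it) {
    List<Integer> out = new ArrayList<>();
    for (List<Integer> file : snapshot) {
      out.addAll(file);
    }
    it.upstream.forEachRemaining(out::add);
    Collections.sort(out);
    return out;
  }

  /** Pre-fix order: publish, sibling spill, then snapshot inside merge(). */
  static List<Integer> buggyOrder() {
    SpillRaceModel m = new SpillRaceModel();
    m.spillWriters.add(new ArrayList<>(List.of(0, 2)));
    ReadingIterator it = m.new ReadingIterator(List.of(1, 3)); // step 2: publish
    it.spill();                                                // sibling spill lands here
    return merge(new ArrayList<>(m.spillWriters), it);        // step 3: late snapshot
  }

  /** Post-fix order: snapshot first, then publish. */
  static List<Integer> fixedOrder() {
    SpillRaceModel m = new SpillRaceModel();
    m.spillWriters.add(new ArrayList<>(List.of(0, 2)));
    List<List<Integer>> snapshot = List.copyOf(m.spillWriters); // snapshot BEFORE publish
    ReadingIterator it = m.new ReadingIterator(List.of(1, 3));
    it.spill();
    return merge(snapshot, it);
  }
}
```

In this model, `buggyOrder()` returns `[0, 1, 1, 2, 3, 3]`. Records 1 and 3 are read once through the late snapshot and once through the rebound upstream. `fixedOrder()` returns `[0, 1, 2, 3]`.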
#### To fix the issue:
As in the non-bounded merge code path, take the snapshot of `spillWriters`
before publishing `readingIterator`. This single change does two things:
- **Closes the race.** The snapshot is now taken strictly before
`readingIterator` is published. A later `readingIterator.spill()` appends only to
the live `spillWriters` list; the snapshot is already fixed. The new spill
file reaches the output exactly once, via `readingIterator.upstream`.
- **Makes the invariant a named API.** The snapshot-before-publish invariant is no
longer a convention defended only by a comment in a long method. It now lives in
`prepareBoundedMerge()`, and the returned field bundle (`BoundedMergerContext`)
also gives the test a clean handle to drive each phase explicitly.
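Here is a minimal sketch of the resulting prepare-then-merge shape. It assumes strings stand in for spill writers and a plain `Iterator<Integer>` stands in for `SpillableIterator`. `PrepareThenMergeSketch` and its `Context` record are hypothetical and only loosely mirror `prepareBoundedMerge()` and `BoundedMergerContext`. The point is the ordering: freeze the inputs with `List.copyOf` first, then publish the iterator.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/** Hypothetical sketch of splitting "prepare" (ordered side effects) from "merge". */
final class PrepareThenMergeSketch {
  /** Everything the merge phase needs, captured in a fixed order. */
  record Context(List<String> snapshot, Iterator<Integer> inMemIter) {}

  /** Live, mutable list of spill "files"; a concurrent spill may append to it. */
  final List<String> spillWriters = new ArrayList<>();

  /** Published last; an off-thread spiller would find the iterator here. */
  volatile Iterator<Integer> readingIterator;

  Context prepare(List<Integer> inMemRecords) {
    // 1. Freeze the merge inputs. List.copyOf returns an unmodifiable copy, so
    //    later appends to spillWriters cannot leak into the snapshot, and any
    //    attempt to mutate the snapshot itself fails fast.
    List<String> snapshot = List.copyOf(spillWriters);
    // 2. Only now publish the iterator that a concurrent spill can act on.
    Iterator<Integer> inMemIter = null;
    if (inMemRecords != null) {
      readingIterator = inMemRecords.iterator();
      inMemIter = readingIterator;
    }
    return new Context(snapshot, inMemIter);
  }
}
```

Appending to `spillWriters` after `prepare()` returns leaves `ctx.snapshot()` unchanged. Calling `add` on the snapshot throws `UnsupportedOperationException`, which is the fail-fast behavior the patch relies on.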
### Why are the changes needed?
To fix the potential race condition and align the bounded merge path with the
existing snapshot-before-publish pattern of the non-bounded path.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added unit test `testBoundedMergeSnapshotIsolatedFromConcurrentSpill` in
`UnsafeExternalSorterSuite`.
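The new test's core check is that every input key appears exactly once in the merged output. Below is a standalone sketch of that check. It assumes the keys form the dense range `[0, total)`, as in the test. `ExactlyOnceCheck` is a hypothetical helper, not part of the patch.

```java
import java.util.BitSet;

/** Hypothetical helper: detects duplicates, out-of-range keys and losses with a BitSet. */
final class ExactlyOnceCheck {
  static boolean isExactlyOnce(int[] values, int total) {
    BitSet seen = new BitSet(total);
    for (int v : values) {
      if (v < 0 || v >= total || seen.get(v)) {
        return false; // out of range, or a duplicate record
      }
      seen.set(v);
    }
    // No duplicates were seen, so cardinality == total means nothing is missing.
    return seen.cardinality() == total;
  }
}
```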
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Claude Code (Opus 4.7)
Closes #55891 from ivoson/SPARK-56873.
Authored-by: Tengfei Huang <[email protected]>
Signed-off-by: yangjie01 <[email protected]>
(cherry picked from commit aa79fed96a42586f60dcce2824dc7a0c76c3440a)
Signed-off-by: yangjie01 <[email protected]>
---
.../unsafe/sort/UnsafeExternalSorter.java | 78 ++++++++++++++------
.../sort/UnsafeSorterBoundedSpillMerger.java | 7 +-
.../unsafe/sort/UnsafeExternalSorterSuite.java | 83 ++++++++++++++++++++++
3 files changed, 145 insertions(+), 23 deletions(-)
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index a40dab8a8dab..2a3678a6b94d 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -594,28 +594,8 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
logger.info("Merging {} spill files using bounded merge with factor {}",
MDC.of(LogKeys.NUM_SPILL_WRITERS, spillWriters.size()),
MDC.of(LogKeys.MERGE_FACTOR, spillMergeFactor));
-
- // This assignment is not inside synchronized(this), unlike the read in
- // cleanupResources(). That is safe because all callers of cleanupResources()
- // (the task completion listener, iterator-end cleanup from wrappers like
- // UnsafeExternalRowSorter / UnsafeKVExternalSorter / SortExec, etc.) run on
- // the task thread, sequentially with getSortedIterator(). The volatile modifier
- // on boundedMerger provides memory visibility across any intervening
- // synchronized blocks.
- boundedMerger = new UnsafeSorterBoundedSpillMerger(
- spillMergeFactor,
- recordComparatorSupplier.get(),
- prefixComparator,
- blockManager,
- serializerManager,
- fileBufferSizeBytes);
-
- UnsafeSorterIterator inMemIter = null;
- if (inMemSorter != null) {
- readingIterator = new SpillableIterator(inMemSorter.getSortedIterator());
- inMemIter = readingIterator;
- }
- return boundedMerger.merge(spillWriters, inMemIter);
+ BoundedMergerContext ctx = prepareBoundedMerge();
+ return ctx.merger.merge(ctx.snapshot, ctx.inMemIter);
} else {
// Original single-round merge: open all spill readers at once
logger.info("Merging {} spill files in single round",
@@ -633,6 +613,60 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
}
}
+ @VisibleForTesting
+ static final class BoundedMergerContext {
+ final List<UnsafeSorterSpillWriter> snapshot;
+ @Nullable final SpillableIterator inMemIter;
+ final UnsafeSorterBoundedSpillMerger merger;
+
+ BoundedMergerContext(
+ List<UnsafeSorterSpillWriter> snapshot,
+ @Nullable SpillableIterator inMemIter,
+ UnsafeSorterBoundedSpillMerger merger) {
+ this.snapshot = snapshot;
+ this.inMemIter = inMemIter;
+ this.merger = merger;
+ }
+ }
+
+ @VisibleForTesting
+ BoundedMergerContext prepareBoundedMerge() {
+ // Snapshot MUST precede readingIterator publication. Once readingIterator is
+ // non-null, a sibling MemoryConsumer's spill request is routed via
+ // readingIterator.spill(), which appends a new writer to spillWriters AND rebinds
+ // readingIterator.upstream to that same file. A post-publication snapshot would
+ // then feed that file to BOTH the snapshot path and readingIterator -- duplicate
+ // records in the merged output. List.copyOf returns an unmodifiable list so any
+ // future code that mutates the snapshot (or aliases the live spillWriters field
+ // into the context and adds to it) fails fast.
+ final List<UnsafeSorterSpillWriter> snapshot = List.copyOf(spillWriters);
+
+ // The volatile fields published below -- boundedMerger and readingIterator -- are
+ // written without holding synchronized(this). Safe because all callers of
+ // getSortedIterator() and cleanupResources() (the task completion listener,
+ // iterator-end cleanup from wrappers like UnsafeExternalRowSorter /
+ // UnsafeKVExternalSorter / SortExec, etc.) run on the task thread, sequentially.
+ // The volatile modifier provides memory visibility to off-task-thread readers:
+ // sibling MemoryConsumer.spill() reads readingIterator, and cleanupResources()'s
+ // synchronized(this) read of boundedMerger crosses any intervening synchronized
+ // blocks.
+ final UnsafeSorterBoundedSpillMerger merger = new UnsafeSorterBoundedSpillMerger(
+ spillMergeFactor,
+ recordComparatorSupplier.get(),
+ prefixComparator,
+ blockManager,
+ serializerManager,
+ fileBufferSizeBytes);
+ boundedMerger = merger;
+
+ SpillableIterator inMemIter = null;
+ if (inMemSorter != null) {
+ readingIterator = new SpillableIterator(inMemSorter.getSortedIterator());
+ inMemIter = readingIterator;
+ }
+ return new BoundedMergerContext(snapshot, inMemIter, merger);
+ }
+
@VisibleForTesting boolean hasSpaceForAnotherRecord() {
return inMemSorter.hasSpaceForAnotherRecord();
}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterBoundedSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterBoundedSpillMerger.java
index 1f389465a8b2..b844f9816bf3 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterBoundedSpillMerger.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterBoundedSpillMerger.java
@@ -90,6 +90,11 @@ final class UnsafeSorterBoundedSpillMerger {
* <p>If {@code inMemIterator} is non-null, it is included in the final merge round
* (not spilled to disk in intermediate rounds).</p>
*
+ * <p>This method does not mutate the input {@code spillWriters} list; intermediate
+ * rounds reassign a local variable to fresh lists. Callers are still responsible for
+ * passing a defensive snapshot if they need to protect against concurrent mutation
+ * of the underlying list (see {@link UnsafeExternalSorter#prepareBoundedMerge}).</p>
+ *
* @param spillWriters the list of spill writers to merge
* @param inMemIterator optional in-memory sorted iterator to include in the final merge
* @return a sorted iterator over all records
@@ -98,7 +103,7 @@ final class UnsafeSorterBoundedSpillMerger {
List<UnsafeSorterSpillWriter> spillWriters,
@Nullable UnsafeSorterIterator inMemIterator) throws IOException {
- List<UnsafeSorterSpillWriter> spillsToMerge = new ArrayList<>(spillWriters);
+ List<UnsafeSorterSpillWriter> spillsToMerge = spillWriters;
int round = 0;
while (spillsToMerge.size() > mergeFactor) {
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index 9ce43d32c1b1..d59bcfc2bd13 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -20,6 +20,7 @@ package org.apache.spark.util.collection.unsafe.sort;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
+import java.util.BitSet;
import java.util.LinkedList;
import java.util.UUID;
@@ -36,6 +37,8 @@ import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.executor.TaskMetrics;
import org.apache.spark.internal.config.package$;
+import org.apache.spark.memory.MemoryConsumer;
+import org.apache.spark.memory.MemoryMode;
import org.apache.spark.memory.TestMemoryManager;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.serializer.JavaSerializer;
@@ -852,6 +855,86 @@ public class UnsafeExternalSorterSuite {
assertSpillFilesWereCleanedUp();
}
+ @Test
+ public void testBoundedMergeSnapshotIsolatedFromConcurrentSpill() throws Exception {
+ // Verifies the prepareBoundedMerge() seam contract: ctx.snapshot is a defensive
+ // copy frozen at prepare-time, isolated from any later mutation of the live
+ // spillWriters list. The test drives the worst-case scenario by direct sequencing:
+ // an external-trigger spill() (the route a sibling MemoryConsumer takes under
+ // memory pressure) appends a writer to live spillWriters AND rebinds
+ // readingIterator.upstream to read it -- the merger must consume that file exactly
+ // once via readingIterator, not twice via the snapshot.
+ final UnsafeExternalSorter sorter = newSorter();
+ sorter.setSpillMergeFactor(2);
+
+ final int numSpills = 4;
+ final int recordsPerSpill = 8;
+ final int totalSpilled = numSpills * recordsPerSpill;
+ final int inMemRecords = 5;
+ final int totalRecords = totalSpilled + inMemRecords;
+
+ // Build numSpills spills with disjoint, interleaved keys.
+ for (int spill = 0; spill < numSpills; spill++) {
+ for (int j = 0; j < recordsPerSpill; j++) {
+ insertNumber(sorter, spill + j * numSpills);
+ }
+ sorter.spill();
+ }
+ // Leave a few records in memory so readingIterator has unread data that a
+ // concurrent spill() can drain into a new spill file.
+ for (int j = 0; j < inMemRecords; j++) {
+ insertNumber(sorter, totalSpilled + j);
+ }
+
+ // Phase 1: snapshot + publish readingIterator (production order).
+ UnsafeExternalSorter.BoundedMergerContext ctx = sorter.prepareBoundedMerge();
+ assertNotNull(ctx.inMemIter,
+ "readingIterator should be published when inMemSorter has data");
+ final int snapshotSizeBefore = ctx.snapshot.size();
+ final int spillFilesBefore = spillFilesCreated.size();
+
+ // Phase 2: external-trigger spill. Routes through readingIterator.spill():
+ // appends a writer to the live spillWriters AND rebinds readingIterator.upstream.
+ final MemoryConsumer externalTrigger =
+ new MemoryConsumer(taskMemoryManager, MemoryMode.ON_HEAP) {
+ @Override
+ public long spill(long size, MemoryConsumer trigger) {
+ return 0;
+ }
+ };
+ long bytesSpilled = sorter.spill(Long.MAX_VALUE, externalTrigger);
+ assertTrue(bytesSpilled > 0L,
+ "external-trigger spill must fire to exercise the seam contract");
+ // Exactly one new spill file should have been produced by the external-trigger spill.
+ assertEquals(spillFilesBefore + 1, spillFilesCreated.size(),
+ "external-trigger spill should produce exactly one new spill file");
+ // Defensive-copy invariant: the post-spill snapshot is unchanged. A future
+ // refactor that aliases ctx.snapshot to the live spillWriters field instead of
+ // copying it would fail this assertion.
+ assertEquals(snapshotSizeBefore, ctx.snapshot.size(),
+ "ctx.snapshot must be isolated from live spillWriters mutation");
+
+ // Phase 3: merge using the frozen snapshot.
+ UnsafeSorterIterator iter = ctx.merger.merge(ctx.snapshot, ctx.inMemIter);
+
+ // Each input record must appear exactly once: no duplicates, no losses.
+ BitSet seen = new BitSet(totalRecords);
+ int count = 0;
+ while (iter.hasNext()) {
+ iter.loadNext();
+ int v = Platform.getInt(iter.getBaseObject(), iter.getBaseOffset());
+ assertTrue(v >= 0 && v < totalRecords, "record out of range: " + v);
+ assertFalse(seen.get(v), "duplicate record observed: " + v);
+ seen.set(v);
+ count++;
+ }
+ assertEquals(totalRecords, count, "wrong record count");
+ assertEquals(totalRecords, seen.cardinality(), "missing records");
+
+ sorter.cleanupResources();
+ assertSpillFilesWereCleanedUp();
+ }
+
@Test
public void testBoundedMergeWithDuplicateKeys() throws Exception {
// Multiple spills contain identical keys. Verifies that all duplicates are