This is an automated email from the ASF dual-hosted git repository.

LuciferYang pushed a commit to branch branch-4.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.2 by this push:
     new 6f70be67b02e [SPARK-56873][CORE] Fix potential race condition in bounded k-way merge in UnsafeExternalSorter
6f70be67b02e is described below

commit 6f70be67b02efad24a9f1773b8489f0e0bed8e0b
Author: Tengfei Huang <[email protected]>
AuthorDate: Wed May 20 00:01:18 2026 +0800

    [SPARK-56873][CORE] Fix potential race condition in bounded k-way merge in UnsafeExternalSorter
    
    ### What changes were proposed in this pull request?
    Hardens the bounded multi-round k-way merge in `UnsafeExternalSorter` that was added in [SPARK-56410](https://issues.apache.org/jira/browse/SPARK-56410), closing a potential race between `getSortedIterator` and `spill`.
    
    The race is not currently reachable through any first-party OSS execution path. It is, however, reachable for users running third-party `SparkPlugin`s that register `MemoryConsumer`s driven from non-task threads: with the bounded merge feature enabled, those plugins may hit the race.
    
    #### The race condition:
    Before this PR, the bounded branch of `getSortedIterator()` did roughly:
    ```java
    boundedMerger = new UnsafeSorterBoundedSpillMerger(...);          // step 1
    readingIterator = new SpillableIterator(...);                     // step 2 -- volatile publish
    return boundedMerger.merge(spillWriters, readingIterator);        // step 3 -- snapshot taken INSIDE merge()
    ```
    
    `UnsafeSorterBoundedSpillMerger.merge()` snapshots its inputs via `new ArrayList<>(spillWriters)` at the top, so the snapshot is taken in step 3, **after** `readingIterator` is published in step 2.
    
    Once `readingIterator` is non-null, a sibling consumer's `acquireExecutionMemory()` that picks this sorter as the spill victim is routed to `readingIterator.spill()`, which (a) appends a new writer to the live `spillWriters` list, and (b) rebinds `readingIterator.upstream` to read that same new file. If this happens between step 2 and step 3, the snapshot taken inside `merge()` includes the new writer, **and** the final merge round also pulls from `readingIterator` (whose upstream now [...]
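    
    To make the interleaving concrete, here is a minimal, single-threaded Java sketch of the pre-fix ordering. It is a toy model under stated assumptions, not Spark code: `BuggyOrderSketch`, `ToyReadingIterator`, `merge`, and `runBuggyOrder` are hypothetical names, spill "files" are plain `List<Integer>` runs, and sorting stands in for the k-way merge. Like the added unit test, it sequences the sibling spill deterministically between step 2 and step 3 instead of relying on real thread timing.
    
    ```java
    import java.util.ArrayList;
    import java.util.Collections;
    import java.util.List;
    
    // Toy model of the PRE-fix ordering. All names are hypothetical, not Spark APIs.
    class BuggyOrderSketch {
    
      // Stand-in for SpillableIterator: reads from an upstream run that a spill
      // can rebind to a freshly written "file".
      static final class ToyReadingIterator {
        List<Integer> upstream;
    
        ToyReadingIterator(List<Integer> inMemory) {
          this.upstream = inMemory;
        }
    
        // Models readingIterator.spill(): write the unread in-memory records to a
        // new file, append it to the LIVE list, and rebind upstream to that file.
        void spill(List<List<Integer>> liveSpillWriters) {
          List<Integer> newFile = new ArrayList<>(upstream);
          liveSpillWriters.add(newFile);
          upstream = newFile;
        }
      }
    
      // Models the old merge(): copies its input at the top, then combines the
      // copied spill files with the in-memory iterator (sorting stands in for
      // the k-way merge).
      static List<Integer> merge(List<List<Integer>> spillWriters, ToyReadingIterator inMem) {
        List<List<Integer>> snapshot = new ArrayList<>(spillWriters);
        List<Integer> out = new ArrayList<>();
        for (List<Integer> file : snapshot) {
          out.addAll(file);
        }
        if (inMem != null) {
          out.addAll(inMem.upstream);
        }
        Collections.sort(out);
        return out;
      }
    
      static List<Integer> runBuggyOrder() {
        List<List<Integer>> spillWriters = new ArrayList<>();
        spillWriters.add(List.of(0, 2));
        spillWriters.add(List.of(1, 3));
        // Step 2: publish the reading iterator over in-memory records 4 and 5.
        ToyReadingIterator readingIterator = new ToyReadingIterator(List.of(4, 5));
        // A sibling consumer's spill lands between step 2 and step 3.
        readingIterator.spill(spillWriters);
        // Step 3: the snapshot taken inside merge() now includes the new file.
        return merge(spillWriters, readingIterator);
      }
    
      public static void main(String[] args) {
        System.out.println(runBuggyOrder()); // prints [0, 1, 2, 3, 4, 4, 5, 5]
      }
    }
    ```
    
    Because the spill lands between publish and snapshot, records `4` and `5` come out twice in this model: once from the snapshot's copy of the new file and once from the rebound `upstream`.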
    
    #### To fix the issue:
    As in the non-bounded merge code path, take the snapshot of `spillWriters` before publishing `readingIterator`. This does two things at once:
    - **Closes the race.** The snapshot is now taken strictly before `readingIterator` is published. A later `readingIterator.spill()` appends only to the live `spillWriters` list, while the snapshot stays locked in. The new spill file therefore reaches the output exactly once, via `readingIterator.upstream`.
    - **Makes the invariant a named API.** The hardening invariant is no longer a convention defended only by a comment inside a long method: it now lives in `prepareBoundedMerge()`, and the returned field bundle (`BoundedMergerContext`) also gives the test a clean handle to drive each phase explicitly.
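    
    The same toy model with the fixed ordering might look like this. Again, the names (`FixedOrderSketch`, `ToyContext`, `prepare`, `runFixedOrder`, `snapshotRejectsMutation`) are hypothetical stand-ins that only loosely mirror `BoundedMergerContext` and `prepareBoundedMerge()`; the snapshot is taken with `List.copyOf` before the reading iterator is published, and `merge` no longer copies its input.
    
    ```java
    import java.util.ArrayList;
    import java.util.Collections;
    import java.util.List;
    
    // Toy model of the POST-fix ordering. All names are hypothetical, not Spark APIs.
    class FixedOrderSketch {
    
      // Same stand-in for SpillableIterator as in the pre-fix model.
      static final class ToyReadingIterator {
        List<Integer> upstream;
    
        ToyReadingIterator(List<Integer> inMemory) {
          this.upstream = inMemory;
        }
    
        // Appends a new "file" to the LIVE list and rebinds upstream to it.
        void spill(List<List<Integer>> liveSpillWriters) {
          List<Integer> newFile = new ArrayList<>(upstream);
          liveSpillWriters.add(newFile);
          upstream = newFile;
        }
      }
    
      // Loosely mirrors BoundedMergerContext: a frozen snapshot plus the published iterator.
      static final class ToyContext {
        final List<List<Integer>> snapshot;
        final ToyReadingIterator inMemIter;
    
        ToyContext(List<List<Integer>> snapshot, ToyReadingIterator inMemIter) {
          this.snapshot = snapshot;
          this.inMemIter = inMemIter;
        }
      }
    
      // Loosely mirrors prepareBoundedMerge(): snapshot strictly BEFORE publishing.
      static ToyContext prepare(List<List<Integer>> liveSpillWriters, List<Integer> inMemory) {
        List<List<Integer>> snapshot = List.copyOf(liveSpillWriters); // unmodifiable copy
        ToyReadingIterator readingIterator = new ToyReadingIterator(inMemory); // publish
        return new ToyContext(snapshot, readingIterator);
      }
    
      // Like the patched merge(): uses the caller's snapshot as-is, no internal copy.
      static List<Integer> merge(List<List<Integer>> spillsToMerge, ToyReadingIterator inMem) {
        List<Integer> out = new ArrayList<>();
        for (List<Integer> file : spillsToMerge) {
          out.addAll(file);
        }
        if (inMem != null) {
          out.addAll(inMem.upstream);
        }
        Collections.sort(out);
        return out;
      }
    
      static List<Integer> runFixedOrder() {
        List<List<Integer>> spillWriters = new ArrayList<>();
        spillWriters.add(List.of(0, 2));
        spillWriters.add(List.of(1, 3));
        ToyContext ctx = prepare(spillWriters, List.of(4, 5));
        // The sibling spill now lands AFTER the snapshot: it grows only the live list.
        ctx.inMemIter.spill(spillWriters);
        return merge(ctx.snapshot, ctx.inMemIter);
      }
    
      // Returns true if the snapshot fails fast on mutation.
      static boolean snapshotRejectsMutation() {
        ToyContext ctx = prepare(new ArrayList<>(), List.of());
        try {
          ctx.snapshot.add(List.of(99));
          return false;
        } catch (UnsupportedOperationException e) {
          return true;
        }
      }
    
      public static void main(String[] args) {
        System.out.println(runFixedOrder());           // prints [0, 1, 2, 3, 4, 5]
        System.out.println(snapshotRejectsMutation()); // prints true
      }
    }
    ```
    
    In this model the late spill grows the live list to three files, but the frozen two-file snapshot ignores it, so records `4` and `5` reach the output once via the rebound `upstream`. Adding to the snapshot throws `UnsupportedOperationException`, which is the fail-fast behavior that the `List.copyOf` comment in `prepareBoundedMerge()` relies on.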
    
    ### Why are the changes needed?
    To fix the potential race condition and to align the bounded merge path with the existing snapshot-before-publish pattern of the non-bounded path.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Unit test added: `testBoundedMergeSnapshotIsolatedFromConcurrentSpill` in `UnsafeExternalSorterSuite`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    Generated-by: Claude Code (Opus 4.7)
    
    Closes #55891 from ivoson/SPARK-56873.
    
    Authored-by: Tengfei Huang <[email protected]>
    Signed-off-by: yangjie01 <[email protected]>
    (cherry picked from commit aa79fed96a42586f60dcce2824dc7a0c76c3440a)
    Signed-off-by: yangjie01 <[email protected]>
---
 .../unsafe/sort/UnsafeExternalSorter.java          | 78 ++++++++++++++------
 .../sort/UnsafeSorterBoundedSpillMerger.java       |  7 +-
 .../unsafe/sort/UnsafeExternalSorterSuite.java     | 83 ++++++++++++++++++++++
 3 files changed, 145 insertions(+), 23 deletions(-)

diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index a40dab8a8dab..2a3678a6b94d 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -594,28 +594,8 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
       logger.info("Merging {} spill files using bounded merge with factor {}",
           MDC.of(LogKeys.NUM_SPILL_WRITERS, spillWriters.size()),
           MDC.of(LogKeys.MERGE_FACTOR, spillMergeFactor));
-
-      // This assignment is not inside synchronized(this), unlike the read in
-      // cleanupResources(). That is safe because all callers of cleanupResources()
-      // (the task completion listener, iterator-end cleanup from wrappers like
-      // UnsafeExternalRowSorter / UnsafeKVExternalSorter / SortExec, etc.) run on
-      // the task thread, sequentially with getSortedIterator(). The volatile modifier
-      // on boundedMerger provides memory visibility across any intervening
-      // synchronized blocks.
-      boundedMerger = new UnsafeSorterBoundedSpillMerger(
-          spillMergeFactor,
-          recordComparatorSupplier.get(),
-          prefixComparator,
-          blockManager,
-          serializerManager,
-          fileBufferSizeBytes);
-
-      UnsafeSorterIterator inMemIter = null;
-      if (inMemSorter != null) {
-        readingIterator = new SpillableIterator(inMemSorter.getSortedIterator());
-        inMemIter = readingIterator;
-      }
-      return boundedMerger.merge(spillWriters, inMemIter);
+      BoundedMergerContext ctx = prepareBoundedMerge();
+      return ctx.merger.merge(ctx.snapshot, ctx.inMemIter);
     } else {
       // Original single-round merge: open all spill readers at once
       logger.info("Merging {} spill files in single round",
@@ -633,6 +613,60 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
     }
   }
 
+  @VisibleForTesting
+  static final class BoundedMergerContext {
+    final List<UnsafeSorterSpillWriter> snapshot;
+    @Nullable final SpillableIterator inMemIter;
+    final UnsafeSorterBoundedSpillMerger merger;
+
+    BoundedMergerContext(
+        List<UnsafeSorterSpillWriter> snapshot,
+        @Nullable SpillableIterator inMemIter,
+        UnsafeSorterBoundedSpillMerger merger) {
+      this.snapshot = snapshot;
+      this.inMemIter = inMemIter;
+      this.merger = merger;
+    }
+  }
+
+  @VisibleForTesting
+  BoundedMergerContext prepareBoundedMerge() {
+    // Snapshot MUST precede readingIterator publication. Once readingIterator is
+    // non-null, a sibling MemoryConsumer's spill request is routed via
+    // readingIterator.spill(), which appends a new writer to spillWriters AND rebinds
+    // readingIterator.upstream to that same file. A post-publication snapshot would
+    // then feed that file to BOTH the snapshot path and readingIterator -- duplicate
+    // records in the merged output. List.copyOf returns an unmodifiable list so any
+    // future code that mutates the snapshot (or aliases the live spillWriters field
+    // into the context and adds to it) fails fast.
+    final List<UnsafeSorterSpillWriter> snapshot = List.copyOf(spillWriters);
+
+    // The volatile fields published below -- boundedMerger and readingIterator -- are
+    // written without holding synchronized(this). Safe because all callers of
+    // getSortedIterator() and cleanupResources() (the task completion listener,
+    // iterator-end cleanup from wrappers like UnsafeExternalRowSorter /
+    // UnsafeKVExternalSorter / SortExec, etc.) run on the task thread, sequentially.
+    // The volatile modifier provides memory visibility to off-task-thread readers:
+    // sibling MemoryConsumer.spill() reads readingIterator, and cleanupResources()'s
+    // synchronized(this) read of boundedMerger crosses any intervening synchronized
+    // blocks.
+    final UnsafeSorterBoundedSpillMerger merger = new UnsafeSorterBoundedSpillMerger(
+        spillMergeFactor,
+        recordComparatorSupplier.get(),
+        prefixComparator,
+        blockManager,
+        serializerManager,
+        fileBufferSizeBytes);
+    boundedMerger = merger;
+
+    SpillableIterator inMemIter = null;
+    if (inMemSorter != null) {
+      readingIterator = new SpillableIterator(inMemSorter.getSortedIterator());
+      inMemIter = readingIterator;
+    }
+    return new BoundedMergerContext(snapshot, inMemIter, merger);
+  }
+
   @VisibleForTesting boolean hasSpaceForAnotherRecord() {
     return inMemSorter.hasSpaceForAnotherRecord();
   }
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterBoundedSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterBoundedSpillMerger.java
index 1f389465a8b2..b844f9816bf3 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterBoundedSpillMerger.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterBoundedSpillMerger.java
@@ -90,6 +90,11 @@ final class UnsafeSorterBoundedSpillMerger {
    * <p>If {@code inMemIterator} is non-null, it is included in the final merge round
    * (not spilled to disk in intermediate rounds).</p>
    *
+   * <p>This method does not mutate the input {@code spillWriters} list; intermediate
+   * rounds reassign a local variable to fresh lists. Callers are still responsible for
+   * passing a defensive snapshot if they need to protect against concurrent mutation
+   * of the underlying list (see {@link UnsafeExternalSorter#prepareBoundedMerge}).</p>
+   *
    * @param spillWriters the list of spill writers to merge
    * @param inMemIterator optional in-memory sorted iterator to include in the final merge
    * @return a sorted iterator over all records
@@ -98,7 +103,7 @@ final class UnsafeSorterBoundedSpillMerger {
       List<UnsafeSorterSpillWriter> spillWriters,
       @Nullable UnsafeSorterIterator inMemIterator) throws IOException {
 
-    List<UnsafeSorterSpillWriter> spillsToMerge = new ArrayList<>(spillWriters);
+    List<UnsafeSorterSpillWriter> spillsToMerge = spillWriters;
     int round = 0;
 
     while (spillsToMerge.size() > mergeFactor) {
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index 9ce43d32c1b1..d59bcfc2bd13 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -20,6 +20,7 @@ package org.apache.spark.util.collection.unsafe.sort;
 import java.io.File;
 import java.io.IOException;
 import java.util.Arrays;
+import java.util.BitSet;
 import java.util.LinkedList;
 import java.util.UUID;
 
@@ -36,6 +37,8 @@ import org.apache.spark.TaskContext;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.executor.TaskMetrics;
 import org.apache.spark.internal.config.package$;
+import org.apache.spark.memory.MemoryConsumer;
+import org.apache.spark.memory.MemoryMode;
 import org.apache.spark.memory.TestMemoryManager;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.serializer.JavaSerializer;
@@ -852,6 +855,86 @@ public class UnsafeExternalSorterSuite {
     assertSpillFilesWereCleanedUp();
   }
 
+  @Test
+  public void testBoundedMergeSnapshotIsolatedFromConcurrentSpill() throws Exception {
+    // Verifies the prepareBoundedMerge() seam contract: ctx.snapshot is a defensive
+    // copy frozen at prepare-time, isolated from any later mutation of the live
+    // spillWriters list. The test drives the worst-case scenario by direct sequencing:
+    // an external-trigger spill() (the route a sibling MemoryConsumer takes under
+    // memory pressure) appends a writer to live spillWriters AND rebinds
+    // readingIterator.upstream to read it -- the merger must consume that file exactly
+    // once via readingIterator, not twice via the snapshot.
+    final UnsafeExternalSorter sorter = newSorter();
+    sorter.setSpillMergeFactor(2);
+
+    final int numSpills = 4;
+    final int recordsPerSpill = 8;
+    final int totalSpilled = numSpills * recordsPerSpill;
+    final int inMemRecords = 5;
+    final int totalRecords = totalSpilled + inMemRecords;
+
+    // Build numSpills spills with disjoint, interleaved keys.
+    for (int spill = 0; spill < numSpills; spill++) {
+      for (int j = 0; j < recordsPerSpill; j++) {
+        insertNumber(sorter, spill + j * numSpills);
+      }
+      sorter.spill();
+    }
+    // Leave a few records in memory so readingIterator has unread data that a
+    // concurrent spill() can drain into a new spill file.
+    for (int j = 0; j < inMemRecords; j++) {
+      insertNumber(sorter, totalSpilled + j);
+    }
+
+    // Phase 1: snapshot + publish readingIterator (production order).
+    UnsafeExternalSorter.BoundedMergerContext ctx = sorter.prepareBoundedMerge();
+    assertNotNull(ctx.inMemIter,
+        "readingIterator should be published when inMemSorter has data");
+    final int snapshotSizeBefore = ctx.snapshot.size();
+    final int spillFilesBefore = spillFilesCreated.size();
+
+    // Phase 2: external-trigger spill. Routes through readingIterator.spill():
+    // appends a writer to the live spillWriters AND rebinds readingIterator.upstream.
+    final MemoryConsumer externalTrigger =
+        new MemoryConsumer(taskMemoryManager, MemoryMode.ON_HEAP) {
+          @Override
+          public long spill(long size, MemoryConsumer trigger) {
+            return 0;
+          }
+        };
+    long bytesSpilled = sorter.spill(Long.MAX_VALUE, externalTrigger);
+    assertTrue(bytesSpilled > 0L,
+        "external-trigger spill must fire to exercise the seam contract");
+    // Exactly one new spill file should have been produced by the external-trigger spill.
+    assertEquals(spillFilesBefore + 1, spillFilesCreated.size(),
+        "external-trigger spill should produce exactly one new spill file");
+    // Defensive-copy invariant: the post-spill snapshot is unchanged. A future
+    // refactor that aliases ctx.snapshot to the live spillWriters field instead of
+    // copying it would fail this assertion.
+    assertEquals(snapshotSizeBefore, ctx.snapshot.size(),
+        "ctx.snapshot must be isolated from live spillWriters mutation");
+
+    // Phase 3: merge using the frozen snapshot.
+    UnsafeSorterIterator iter = ctx.merger.merge(ctx.snapshot, ctx.inMemIter);
+
+    // Each input record must appear exactly once: no duplicates, no losses.
+    BitSet seen = new BitSet(totalRecords);
+    int count = 0;
+    while (iter.hasNext()) {
+      iter.loadNext();
+      int v = Platform.getInt(iter.getBaseObject(), iter.getBaseOffset());
+      assertTrue(v >= 0 && v < totalRecords, "record out of range: " + v);
+      assertFalse(seen.get(v), "duplicate record observed: " + v);
+      seen.set(v);
+      count++;
+    }
+    assertEquals(totalRecords, count, "wrong record count");
+    assertEquals(totalRecords, seen.cardinality(), "missing records");
+
+    sorter.cleanupResources();
+    assertSpillFilesWereCleanedUp();
+  }
+
   @Test
   public void testBoundedMergeWithDuplicateKeys() throws Exception {
     // Multiple spills contain identical keys. Verifies that all duplicates are

