This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new 297cce16b67b [SPARK-51272][CORE] Aborting instead of continuing partially completed indeterminate result stage at ResubmitFailedStages
297cce16b67b is described below

commit 297cce16b67ba1ef84181b5ce5413f441012a87c
Author: attilapiros <[email protected]>
AuthorDate: Mon May 19 13:55:45 2025 +0800

    [SPARK-51272][CORE] Aborting instead of continuing partially completed indeterminate result stage at ResubmitFailedStages
    
    ### What changes were proposed in this pull request?
    
    This PR aborts a partially completed indeterminate result stage instead of resubmitting it.
    
    ### Why are the changes needed?
    
    Compared to a shuffle map stage, a result stage has more output and more intermediate state:
    - It can use a `FileOutputCommitter`, where each task performs a Hadoop task commit. A resubmit would then re-commit that Hadoop task (possibly with different content).
    - In the case of a JDBC write, it may have already inserted all rows of a partition into the target schema (see the sketch below).
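
    To make the side effect concrete, here is a minimal, hypothetical sketch (plain JDBC, not Spark's actual writer code) of what a result task that writes a partition might do; the table name and schema are illustrative:

    ```scala
    import java.sql.DriverManager

    // Hypothetical result-task body: inserts one partition's rows into a table.
    // If the task already committed once and the stage is resubmitted, the same
    // rows are inserted again -- and, with an indeterminate parent, the re-run
    // may even insert different rows than the first attempt did.
    def writePartition(jdbcUrl: String, rows: Iterator[(Int, String)]): Unit = {
      val conn = DriverManager.getConnection(jdbcUrl)
      try {
        val stmt = conn.prepareStatement("INSERT INTO target(id, value) VALUES (?, ?)")
        rows.foreach { case (id, value) =>
          stmt.setInt(1, id)
          stmt.setString(2, value)
          stmt.executeUpdate() // a side effect Spark cannot roll back
        }
      } finally {
        conn.close()
      }
    }
    ```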
    
    Ignoring the resubmit when a recalculation is needed would cause data corruption, as the partial result is based on the previous indeterminate computation, while continuing means finishing the stage with newly recomputed data that may be inconsistent with it.
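
    A minimal sketch of the indeterminacy involved, assuming a running `SparkContext` named `sc` (the random shuffle stands in for any source of nondeterministic ordering):

    ```scala
    import scala.util.Random

    // A round-robin repartition assigns rows to output partitions based on their
    // order within the input partition, so a nondeterministic order makes the
    // shuffle output indeterminate: two attempts of the same map task can send
    // the same row to different reducers.
    val indeterminate = sc.parallelize(1 to 1000, 4)
      .mapPartitions(it => Random.shuffle(it.toSeq).iterator)
      .repartition(4)
    // If result partition 0 was computed against attempt 1 of the parent and
    // partition 1 against attempt 2, the combined output can lose or duplicate
    // rows -- the corruption described above.
    ```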
    
    As long as rollback of a result stage is not supported (https://issues.apache.org/jira/browse/SPARK-25342), the best we can do when a recalculation is needed is to abort the stage.
    
    The code before this PR already tried to address a similar situation in the handling of `FetchFailed`, when the fetch comes from an indeterminate shuffle map stage: https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala#L2178-L2182
    
    But this is not enough, as a `FetchFailed` from a determinate stage can lead to an executor loss and a recompute of the indeterminate parent of the result stage, as shown in the attached unit test.
    
    Moreover, `ResubmitFailedStages` can race with a successful `CompletionEvent`. This is why this PR detects the partial execution at the resubmit of the indeterminate result stage.
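
    A condensed view of how that partial execution can be detected at resubmit time (this mirrors the missing-partition check in `abortStageWithInvalidRollBack` in the diff below; the helper name here is ours, not Spark's):

    ```scala
    // If fewer partitions are missing than the stage has tasks, at least one
    // result task has already completed, so the stage must be aborted instead
    // of resubmitted.
    def hasPartialResult(resultStage: ResultStage): Boolean =
      resultStage.findMissingPartitions().length < resultStage.numTasks
    ```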
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    New unit tests were added to illustrate the situations described above.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #50630 from attilapiros/SPARK-51272_attila_3.
    
    Lead-authored-by: attilapiros <[email protected]>
    Co-authored-by: Mridul Muralidharan <mridulatgmail.com>
    Co-authored-by: Peter Toth <[email protected]>
    Co-authored-by: Attila Zsolt Piros <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
    (cherry picked from commit 7604f677d9280cb370071a304fb1a1b6ca047609)
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../org/apache/spark/scheduler/DAGScheduler.scala  | 144 ++++++++++++--------
 .../apache/spark/scheduler/DAGSchedulerSuite.scala | 147 +++++++++++++++++----
 2 files changed, 212 insertions(+), 79 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index aee92ba928b4..baf0ed4df530 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1552,6 +1552,26 @@ private[spark] class DAGScheduler(
     // `findMissingPartitions()` returns all partitions every time.
     stage match {
       case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable =>
+        // already executed at least once
+        if (sms.getNextAttemptId > 0) {
+          // While we previously validated possible rollbacks during the handling of a FetchFailure,
+          // where we were fetching from an indeterminate source map stage, this later check
+          // covers additional cases like recalculating an indeterminate stage after an executor
+          // loss. Moreover, because this check occurs later in the process, if a result stage task
+          // has successfully completed, we can detect this and abort the job, as rolling back a
+          // result stage is not possible.
+          val stagesToRollback = collectSucceedingStages(sms)
+          abortStageWithInvalidRollBack(stagesToRollback)
+          // stages which cannot be rolled back were aborted, which leads to removing
+          // the dependent job(s) from the active jobs set
+          val numActiveJobsWithStageAfterRollback =
+            activeJobs.count(job => stagesToRollback.contains(job.finalStage))
+          if (numActiveJobsWithStageAfterRollback == 0) {
+            logInfo(log"All jobs depending on the indeterminate stage " +
+              log"(${MDC(STAGE_ID, stage.id)}) were aborted so this stage is 
not needed anymore.")
+            return
+          }
+        }
         mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId)
         sms.shuffleDep.newShuffleMergeState()
       case _ =>
@@ -2129,60 +2149,8 @@ private[spark] class DAGScheduler(
               // guaranteed to be determinate, so the input data of the reducers will not change
               // even if the map tasks are re-tried.
               if (mapStage.isIndeterminate) {
-                // It's a little tricky to find all the succeeding stages of `mapStage`, because
-                // each stage only know its parents not children. Here we traverse the stages from
-                // the leaf nodes (the result stages of active jobs), and rollback all the stages
-                // in the stage chains that connect to the `mapStage`. To speed up the stage
-                // traversing, we collect the stages to rollback first. If a stage needs to
-                // rollback, all its succeeding stages need to rollback to.
-                val stagesToRollback = HashSet[Stage](mapStage)
-
-                def collectStagesToRollback(stageChain: List[Stage]): Unit = {
-                  if (stagesToRollback.contains(stageChain.head)) {
-                    stageChain.drop(1).foreach(s => stagesToRollback += s)
-                  } else {
-                    stageChain.head.parents.foreach { s =>
-                      collectStagesToRollback(s :: stageChain)
-                    }
-                  }
-                }
-
-                def generateErrorMessage(stage: Stage): String = {
-                  "A shuffle map stage with indeterminate output was failed and retried. " +
-                    s"However, Spark cannot rollback the $stage to re-process the input data, " +
-                    "and has to fail this job. Please eliminate the indeterminacy by " +
-                    "checkpointing the RDD before repartition and try again."
-                }
-
-                activeJobs.foreach(job => collectStagesToRollback(job.finalStage :: Nil))
-
-                // The stages will be rolled back after checking
-                val rollingBackStages = HashSet[Stage](mapStage)
-                stagesToRollback.foreach {
-                  case mapStage: ShuffleMapStage =>
-                    val numMissingPartitions = mapStage.findMissingPartitions().length
-                    if (numMissingPartitions < mapStage.numTasks) {
-                      if (sc.conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)) {
-                        val reason = "A shuffle map stage with indeterminate output was failed " +
-                          "and retried. However, Spark can only do this while using the new " +
-                          "shuffle block fetching protocol. Please check the config " +
-                          "'spark.shuffle.useOldFetchProtocol', see more detail in " +
-                          "SPARK-27665 and SPARK-25341."
-                        abortStage(mapStage, reason, None)
-                      } else {
-                        rollingBackStages += mapStage
-                      }
-                    }
-
-                  case resultStage: ResultStage if resultStage.activeJob.isDefined =>
-                    val numMissingPartitions = resultStage.findMissingPartitions().length
-                    if (numMissingPartitions < resultStage.numTasks) {
-                      // TODO: support to rollback result tasks.
-                      abortStage(resultStage, generateErrorMessage(resultStage), None)
-                    }
-
-                  case _ =>
-                }
+                val stagesToRollback = collectSucceedingStages(mapStage)
+                val rollingBackStages = abortStageWithInvalidRollBack(stagesToRollback)
                 logInfo(log"The shuffle map stage ${MDC(SHUFFLE_ID, mapStage)} with indeterminate output was failed, " +
                   log"we will roll back and rerun below stages which include itself and all its " +
                   log"indeterminate child stages: ${MDC(STAGES, rollingBackStages)}")
@@ -2346,6 +2314,74 @@ private[spark] class DAGScheduler(
     }
   }
 
+  private def collectSucceedingStages(mapStage: ShuffleMapStage): HashSet[Stage] = {
+    // TODO: perhaps materialize this if we are going to compute it often enough?
+    // It's a little tricky to find all the succeeding stages of `mapStage`, because
+    // each stage only knows its parents, not its children. Here we traverse the stages from
+    // the leaf nodes (the result stages of active jobs), and roll back all the stages
+    // in the stage chains that connect to the `mapStage`. To speed up the stage
+    // traversing, we collect the stages to roll back first. If a stage needs to
+    // roll back, all its succeeding stages need to roll back too.
+    val succeedingStages = HashSet[Stage](mapStage)
+
+    def collectSucceedingStagesInternal(stageChain: List[Stage]): Unit = {
+      if (succeedingStages.contains(stageChain.head)) {
+        stageChain.drop(1).foreach(s => succeedingStages += s)
+      } else {
+        stageChain.head.parents.foreach { s =>
+          collectSucceedingStagesInternal(s :: stageChain)
+        }
+      }
+    }
+    activeJobs.foreach(job => collectSucceedingStagesInternal(job.finalStage :: Nil))
+    succeedingStages
+  }
+
+  /**
+   * Abort stages where roll back is requested but cannot be completed.
+   *
+   * @param stagesToRollback stages to roll back
+   * @return Shuffle map stages which need and can be rolled back
+   */
+  private def abortStageWithInvalidRollBack(stagesToRollback: HashSet[Stage]): HashSet[Stage] = {
+
+    def generateErrorMessage(stage: Stage): String = {
+      "A shuffle map stage with indeterminate output was failed and retried. " 
+
+        s"However, Spark cannot rollback the $stage to re-process the input 
data, " +
+        "and has to fail this job. Please eliminate the indeterminacy by " +
+        "checkpointing the RDD before repartition and try again."
+    }
+
+    // The stages will be rolled back after checking
+    val rollingBackStages = HashSet[Stage]()
+    stagesToRollback.foreach {
+      case mapStage: ShuffleMapStage =>
+        if (mapStage.numAvailableOutputs > 0) {
+          if (sc.conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)) {
+            val reason = "A shuffle map stage with indeterminate output was 
failed " +
+              "and retried. However, Spark can only do this while using the 
new " +
+              "shuffle block fetching protocol. Please check the config " +
+              "'spark.shuffle.useOldFetchProtocol', see more detail in " +
+              "SPARK-27665 and SPARK-25341."
+            abortStage(mapStage, reason, None)
+          } else {
+            rollingBackStages += mapStage
+          }
+        }
+
+      case resultStage: ResultStage if resultStage.activeJob.isDefined =>
+        val numMissingPartitions = resultStage.findMissingPartitions().length
+        if (numMissingPartitions < resultStage.numTasks) {
+          // TODO: support to rollback result tasks.
+          abortStage(resultStage, generateErrorMessage(resultStage), None)
+        }
+
+      case _ =>
+    }
+
+    rollingBackStages
+  }
+
   /**
    * Whether executor is decommissioning or decommissioned.
    * Return true when:
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 3e507df706ba..d4e90be7c66d 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -18,11 +18,11 @@
 package org.apache.spark.scheduler
 
 import java.util.{ArrayList => JArrayList, Collections => JCollections, Properties}
-import java.util.concurrent.{CountDownLatch, Delayed, ScheduledFuture, TimeUnit}
+import java.util.concurrent.{CountDownLatch, Delayed, LinkedBlockingQueue, ScheduledFuture, TimeUnit}
 import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong, AtomicReference}
 
 import scala.annotation.meta.param
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, ListBuffer, Map}
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
 import scala.jdk.CollectionConverters._
 import scala.language.reflectiveCalls
 import scala.util.control.NonFatal
@@ -56,28 +56,31 @@ class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler)
 
   dagScheduler.setEventProcessLoop(this)
 
-  private var isProcessing = false
-  private val eventQueue = new ListBuffer[DAGSchedulerEvent]()
-
+  private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent]()
 
   override def post(event: DAGSchedulerEvent): Unit = {
-    if (isProcessing) {
-      // `DAGSchedulerEventProcessLoop` is guaranteed to process events sequentially. So we should
-      // buffer events for sequent processing later instead of processing them recursively.
-      eventQueue += event
-    } else {
-      try {
-        isProcessing = true
-        // Forward event to `onReceive` directly to avoid processing event asynchronously.
-        onReceive(event)
-      } catch {
-        case NonFatal(e) => onError(e)
-      } finally {
-        isProcessing = false
-      }
-      if (eventQueue.nonEmpty) {
-        post(eventQueue.remove(0))
-      }
+    // `DAGSchedulerEventProcessLoop` is guaranteed to process events sequentially in the main test
+    // thread, similarly to how it is done in production using the "dag-scheduler-event-loop".
+    // So we should buffer events for subsequent processing later instead of executing them
+    // on the thread calling post() (which might be the "dag-scheduler-message" thread for some
+    // events posted by the DAGScheduler itself)
+    eventQueue.put(event)
+  }
+
+  def runEvents(): Unit = {
+    var dagEvent = eventQueue.poll()
+    while (dagEvent != null) {
+      onReciveWithErrorHandler(dagEvent)
+      dagEvent = eventQueue.poll()
+    }
+  }
+
+  private def onReciveWithErrorHandler(event: DAGSchedulerEvent): Unit = {
+    try {
+      // Forward event to `onReceive` directly to avoid processing event asynchronously.
+      onReceive(event)
+    } catch {
+      case NonFatal(e) => onError(e)
     }
   }
 
@@ -306,7 +309,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
   var broadcastManager: BroadcastManager = null
   var securityMgr: SecurityManager = null
   var scheduler: DAGScheduler = null
-  var dagEventProcessLoopTester: DAGSchedulerEventProcessLoop = null
+  var dagEventProcessLoopTester: DAGSchedulerEventProcessLoopTester = null
 
   /**
    * Set of cache locations to return from our mock BlockManagerMaster.
@@ -479,6 +482,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
     // Ensure the initialization of various components
     sc
     dagEventProcessLoopTester.post(event)
+    dagEventProcessLoopTester.runEvents()
   }
 
   /**
@@ -1190,11 +1194,12 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
   private def completeNextStageWithFetchFailure(
       stageId: Int,
       attemptIdx: Int,
-      shuffleDep: ShuffleDependency[_, _, _]): Unit = {
+      shuffleDep: ShuffleDependency[_, _, _],
+      srcHost: String = "hostA"): Unit = {
     val stageAttempt = taskSets.last
     checkStageId(stageId, attemptIdx, stageAttempt)
     complete(stageAttempt, stageAttempt.tasks.zipWithIndex.map { case (task, idx) =>
-      (FetchFailed(makeBlockManagerId("hostA"), shuffleDep.shuffleId, 0L, 0, idx, "ignored"), null)
+      (FetchFailed(makeBlockManagerId(srcHost), shuffleDep.shuffleId, 0L, 0, idx, "ignored"), null)
     }.toSeq)
   }
 
@@ -2251,6 +2256,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
     assert(completedStage === List(0, 1))
 
     Thread.sleep(DAGScheduler.RESUBMIT_TIMEOUT * 2)
+    dagEventProcessLoopTester.runEvents()
     // map stage resubmitted
     assert(scheduler.runningStages.size === 1)
     val mapStage = scheduler.runningStages.head
@@ -2286,6 +2292,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
     sc.listenerBus.waitUntilEmpty()
 
     Thread.sleep(DAGScheduler.RESUBMIT_TIMEOUT * 2)
+    dagEventProcessLoopTester.runEvents()
     // map stage is running by resubmitted, result stage is waiting
     // map tasks and the origin result task 1.0 are running
     assert(scheduler.runningStages.size == 1, "Map stage should be running")
@@ -3125,6 +3132,92 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
     assert(countSubmittedMapStageAttempts() === 2)
   }
 
+  /**
+   * This function creates the following dependency graph:
+   *
+   * (determinate)        (indeterminate)
+   * shuffleMapRdd0       shuffleMapRDD1
+   *              \       /
+   *               \     /
+   *               finalRdd
+   *
+   * Both ShuffleMapRdds will be ShuffleMapStages with 2 partitions executed on
+   * hostA_exec and hostB_exec.
+   */
+  def constructMixedDeterminateDependencies():
+    (ShuffleDependency[_, _, _], ShuffleDependency[_, _, _]) = {
+    val numPartitions = 2
+    val shuffleMapRdd0 = new MyRDD(sc, numPartitions, Nil, indeterminate = false)
+    val shuffleDep0 = new ShuffleDependency(shuffleMapRdd0, new HashPartitioner(2))
+
+    val shuffleMapRdd1 =
+      new MyRDD(sc, numPartitions, Nil, tracker = mapOutputTracker, indeterminate = true)
+    val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(2))
+
+    val finalRdd =
+      new MyRDD(sc, numPartitions, List(shuffleDep0, shuffleDep1), tracker = mapOutputTracker)
+
+    submit(finalRdd, Array(0, 1))
+
+    // Finish the first shuffle map stage.
+    completeShuffleMapStageSuccessfully(0, 0, numPartitions, Seq("hostA", "hostB"))
+    completeShuffleMapStageSuccessfully(1, 0, numPartitions, Seq("hostA", "hostB"))
+    assert(mapOutputTracker.findMissingPartitions(0) === Some(Seq.empty))
+    assert(mapOutputTracker.findMissingPartitions(1) === Some(Seq.empty))
+
+    (shuffleDep0, shuffleDep1)
+  }
+
+  test("SPARK-51272: re-submit of an indeterminate stage without partial result can succeed") {
+    val shuffleDeps = constructMixedDeterminateDependencies()
+    val resultStage = scheduler.stageIdToStage(2).asInstanceOf[ResultStage]
+
+    // the fetch failure is from the determinate shuffle map stage but this leads to
+    // an executor loss and removing the shuffle files generated by the indeterminate stage too
+    completeNextStageWithFetchFailure(resultStage.id, 0, shuffleDeps._1, "hostA")
+
+    Thread.sleep(DAGScheduler.RESUBMIT_TIMEOUT * 2)
+    dagEventProcessLoopTester.runEvents()
+    assert(scheduler.runningStages.size === 2)
+    assert(scheduler.runningStages.forall(_.isInstanceOf[ShuffleMapStage]))
+
+    completeShuffleMapStageSuccessfully(0, 1, 2, Seq("hostA", "hostB"))
+    completeShuffleMapStageSuccessfully(1, 1, 2, Seq("hostA", "hostB"))
+    assert(scheduler.runningStages.size === 1)
+    assert(scheduler.runningStages.head === resultStage)
+    assert(resultStage.latestInfo.failureReason.isEmpty)
+
+    completeNextResultStageWithSuccess(resultStage.id, 1)
+  }
+
+  test("SPARK-51272: re-submit of an indeterminate stage with partial result will fail") {
+    val shuffleDeps = constructMixedDeterminateDependencies()
+    val resultStage = scheduler.stageIdToStage(2).asInstanceOf[ResultStage]
+
+    runEvent(makeCompletionEvent(taskSets(2).tasks(0), Success, 42))
+    // the fetch failure is from the determinate shuffle map stage but this leads to
+    // an executor loss and removing the shuffle files generated by the indeterminate stage too
+    runEvent(makeCompletionEvent(
+      taskSets(2).tasks(1),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleDeps._1.shuffleId, 0L, 0, 0, "ignored"),
+      null))
+
+    dagEventProcessLoopTester.runEvents()
+    // resubmission has not yet happened, so job is still running
+    assert(scheduler.activeJobs.nonEmpty)
+    Thread.sleep(DAGScheduler.RESUBMIT_TIMEOUT * 2)
+    dagEventProcessLoopTester.runEvents()
+
+    // all dependent jobs have been failed
+    assert(scheduler.runningStages.size === 0)
+    assert(scheduler.activeJobs.isEmpty)
+    assert(resultStage.latestInfo.failureReason.isDefined)
+    assert(resultStage.latestInfo.failureReason.get.
+      contains("A shuffle map stage with indeterminate output was failed and retried. " +
+        "However, Spark cannot rollback the ResultStage"))
+    assert(scheduler.activeJobs.isEmpty, "Aborting the stage aborts the job as well.")
+  }
+
   private def constructIndeterminateStageFetchFailed(): (Int, Int) = {
     val shuffleMapRdd1 = new MyRDD(sc, 2, Nil, indeterminate = true)
 
@@ -4884,6 +4977,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
     // wait resubmit
     sc.listenerBus.waitUntilEmpty()
     Thread.sleep(DAGScheduler.RESUBMIT_TIMEOUT * 2)
+    dagEventProcessLoopTester.runEvents()
 
     // stage0 retry
     val stage0Retry = taskSets.filter(_.stageId == 1)
@@ -4984,6 +5078,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
 
       // the stages will now get resubmitted due to the failure
       Thread.sleep(DAGScheduler.RESUBMIT_TIMEOUT * 2)
+      dagEventProcessLoopTester.runEvents()
 
       // parent map stage resubmitted
       assert(scheduler.runningStages.size === 1)
@@ -5003,6 +5098,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
         result = MapStatus(BlockManagerId("hostF-exec1", "hostF", 12345),
           Array.fill[Long](2)(2), mapTaskId = taskIdCount)))
       Thread.sleep(DAGScheduler.RESUBMIT_TIMEOUT * 2)
+      dagEventProcessLoopTester.runEvents()
 
       // The retries should succeed
       sc.listenerBus.waitUntilEmpty()
@@ -5012,6 +5108,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
       // This will add 3 new stages.
       submit(reduceRdd, Array(0, 1))
       Thread.sleep(DAGScheduler.RESUBMIT_TIMEOUT * 2)
+      dagEventProcessLoopTester.runEvents()
 
       // Only the last stage needs to execute, and those tasks - so completed stages should not
       // change.


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
