This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 922adad1e3cc [SPARK-53575][CORE] Retry entire consumer stages when
checksum mismatch detected for a retried shuffle map task
922adad1e3cc is described below
commit 922adad1e3cc4eae70a7d10f852bf585835ef21a
Author: Tengfei Huang <[email protected]>
AuthorDate: Mon Sep 29 17:53:37 2025 +0800
[SPARK-53575][CORE] Retry entire consumer stages when checksum mismatch
detected for a retried shuffle map task
### What changes were proposed in this pull request?
This PR proposes to retry all tasks of the consumer stages when checksum
mismatches are detected on their producer stages. If we cannot roll back and
retry all tasks of a consumer stage, we have to abort the stage (and thus the
job).
How we detected and handled nondeterminism before (see the sketch after this list):
- Stages are labeled as indeterminate at planning time, prior to query
execution.
- When a task completes with `FetchFailed`, we abort all unrollbackable
succeeding stages of the map stage and resubmit the failed stages.
- In `submitMissingTasks()`, if the stage itself is indeterminate
(`isIndeterminate`), we call `unregisterAllMapAndMergeOutput()` and retry all
tasks of the stage.
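A minimal standalone sketch of that legacy decision, assuming the planning-time
`isIndeterminate` flag is the only signal (the `LegacyStage` case class and the
helpers below are hypothetical stand-ins, not the actual `DAGScheduler` code):
```scala
object LegacyRetryDemo extends App {
  // Hypothetical, simplified model of the pre-existing logic in submitMissingTasks();
  // the real code lives in core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala.
  case class LegacyStage(isIndeterminate: Boolean, isAvailable: Boolean, nextAttemptId: Int)

  // All map (and merge) outputs are discarded and every task re-runs whenever the
  // stage was marked indeterminate at planning time and is not fully available.
  def needsFullStageRetry(stage: LegacyStage): Boolean =
    stage.isIndeterminate && !stage.isAvailable

  // The rollback/abort validation of succeeding stages only happens once the
  // stage has already been executed at least once.
  def needsRollbackCheck(stage: LegacyStage): Boolean =
    needsFullStageRetry(stage) && stage.nextAttemptId > 0

  val retried = LegacyStage(isIndeterminate = true, isAvailable = false, nextAttemptId = 1)
  println(needsFullStageRetry(retried))                               // true
  println(needsRollbackCheck(retried))                                // true
  println(needsFullStageRetry(retried.copy(isIndeterminate = false))) // false
}
```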
How we detect and handle nondeterminism now (see the sketch after the note below):
- During query execution, we keep track of the checksums produced by each
map task.
- When a task completes and a checksum mismatch is detected, we abort the
unrollbackable succeeding stages of the stage with the mismatch. Failed stages
are still resubmitted in the same places as before.
- In `submitMissingTasks()`, if a parent of the stage has checksum
mismatches, we call `unregisterAllMapAndMergeOutput()` and retry all tasks of
the stage.
Note that (1) if a stage is reliably checkpointed (`isReliablyCheckpointed`),
its consumer stages do not need a whole-stage retry, and (2) when mismatches
are detected for a stage in a chain (e.g., stage_i in stage_i -> stage_i+1 ->
stage_i+2 -> ...), the direct consumer (stage_i+1) gets a whole-stage retry
right away, while an indirect consumer (stage_i+2) gets a whole-stage retry
only once its own parent detects checksum mismatches.
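A minimal standalone sketch of the new decision, mirroring `isStageIndeterminate`
and `isParentIndeterminate` from `Stage.scala` in this diff (the `StageModel`
case class and the demo values are hypothetical stand-ins, not the real
scheduler types):
```scala
object ChecksumRetryDemo extends App {
  // Hypothetical stand-in for Stage; the real signals are rdd.isReliablyCheckpointed,
  // the per-stage isChecksumMismatched flag, and the stage's parents.
  case class StageModel(
      reliablyCheckpointed: Boolean,
      checksumMismatched: Boolean,
      parents: Seq[StageModel] = Nil) {
    // A stage counts as indeterminate only when it is not reliably checkpointed
    // and a re-attempted map task actually produced a different checksum.
    def isStageIndeterminate: Boolean = !reliablyCheckpointed && checksumMismatched
    // A consumer must re-run all of its tasks when any direct parent is indeterminate.
    def isParentIndeterminate: Boolean = parents.exists(_.isStageIndeterminate)
  }

  val producer = StageModel(reliablyCheckpointed = false, checksumMismatched = true)
  val consumer = StageModel(reliablyCheckpointed = false, checksumMismatched = false, Seq(producer))

  // With checksumMismatchFullRetryEnabled, the direct consumer gets a whole-stage retry.
  println(consumer.isParentIndeterminate) // true

  // A reliably checkpointed producer never forces the retry, even with a mismatch recorded.
  val checkpointed = producer.copy(reliablyCheckpointed = true)
  println(consumer.copy(parents = Seq(checkpointed)).isParentIndeterminate) // false

  // An indirect consumer is retried only once its own parent records a mismatch.
  val grandConsumer = StageModel(reliablyCheckpointed = false, checksumMismatched = false, Seq(consumer))
  println(grandConsumer.isParentIndeterminate) // false, until `consumer` itself mismatches
}
```
This matches the note above: checkpointing short-circuits the retry, and
whole-stage retries propagate down the chain one hop at a time.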
### Why are the changes needed?
Handle nondeterministic issues caused by the retry of shuffle map task.
### Does this PR introduce _any_ user-facing change?
No
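The new behavior is gated behind the SQLConf added in this patch,
`spark.sql.shuffle.orderIndependentChecksum.enableFullRetryOnMismatch`, which
defaults to false. A minimal sketch of opting in, assuming a local SparkSession
(the snippet itself is not part of this patch):
```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[*]")
  .appName("checksum-mismatch-full-retry-demo")
  // New in this patch; defaults to false.
  .config("spark.sql.shuffle.orderIndependentChecksum.enableFullRetryOnMismatch", "true")
  // Existing, related conf; either flag makes ShuffleExchangeExec attach row-based checksums.
  .config("spark.sql.shuffle.orderIndependentChecksum.enabled", "true")
  .getOrCreate()

// Shuffles now carry per-partition checksums, so a retried map task that
// produces different data triggers a whole-stage retry of its consumer stages.
spark.range(1000).repartition(10).count()
```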
### How was this patch tested?
UTs added.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #52336 from ivoson/SPARK-53575.
Authored-by: Tengfei Huang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../main/scala/org/apache/spark/Dependency.scala | 3 +-
.../scala/org/apache/spark/MapOutputTracker.scala | 10 +-
core/src/main/scala/org/apache/spark/rdd/RDD.scala | 2 +-
.../org/apache/spark/scheduler/DAGScheduler.scala | 101 +++++---
.../scala/org/apache/spark/scheduler/Stage.scala | 22 ++
.../apache/spark/scheduler/DAGSchedulerSuite.scala | 271 ++++++++++++++++++++-
.../org/apache/spark/sql/internal/SQLConf.scala | 11 +
.../execution/exchange/ShuffleExchangeExec.scala | 9 +-
.../apache/spark/sql/MapStatusEndToEndSuite.scala | 43 ++--
9 files changed, 407 insertions(+), 65 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala
index 93a2bbe25157..c436025e06bb 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -89,7 +89,8 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
val aggregator: Option[Aggregator[K, V, C]] = None,
val mapSideCombine: Boolean = false,
val shuffleWriterProcessor: ShuffleWriteProcessor = new ShuffleWriteProcessor,
- val rowBasedChecksums: Array[RowBasedChecksum] = ShuffleDependency.EMPTY_ROW_BASED_CHECKSUMS)
+ val rowBasedChecksums: Array[RowBasedChecksum] = ShuffleDependency.EMPTY_ROW_BASED_CHECKSUMS,
+ val checksumMismatchFullRetryEnabled: Boolean = false)
extends Dependency[Product2[K, V]] with Logging {
def this(
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 3f823b60156a..334eb832c4c2 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -165,9 +165,11 @@ private class ShuffleStatus(
/**
* Register a map output. If there is already a registered location for the map output then it
- * will be replaced by the new location.
+ * will be replaced by the new location. Returns true if the checksum in the new MapStatus is
+ * different from a previous registered MapStatus. Otherwise, returns false.
*/
- def addMapOutput(mapIndex: Int, status: MapStatus): Unit = withWriteLock {
+ def addMapOutput(mapIndex: Int, status: MapStatus): Boolean = withWriteLock {
+ var isChecksumMismatch: Boolean = false
val currentMapStatus = mapStatuses(mapIndex)
if (currentMapStatus == null) {
_numAvailableMapOutputs += 1
@@ -183,9 +185,11 @@ private class ShuffleStatus(
logInfo(s"Checksum of map output changes from ${preStatus.checksumValue}
to " +
s"${status.checksumValue} for task ${status.mapId}.")
checksumMismatchIndices.add(mapIndex)
+ isChecksumMismatch = true
}
mapStatuses(mapIndex) = status
mapIdToMapIndex(status.mapId) = mapIndex
+ isChecksumMismatch
}
/**
@@ -853,7 +857,7 @@ private[spark] class MapOutputTrackerMaster(
}
}
- def registerMapOutput(shuffleId: Int, mapIndex: Int, status: MapStatus): Unit = {
+ def registerMapOutput(shuffleId: Int, mapIndex: Int, status: MapStatus): Boolean = {
shuffleStatuses(shuffleId).addMapOutput(mapIndex, status)
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 117b2925710d..d1408ee774ce 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1773,7 +1773,7 @@ abstract class RDD[T: ClassTag](
/**
* Return whether this RDD is reliably checkpointed and materialized.
*/
- private[rdd] def isReliablyCheckpointed: Boolean = {
+ private[spark] def isReliablyCheckpointed: Boolean = {
checkpointData match {
case Some(reliable: ReliableRDDCheckpointData[_]) if
reliable.isCheckpointed => true
case _ => false
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 30eb49b0c079..3b719a2c7d24 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1551,29 +1551,46 @@ private[spark] class DAGScheduler(
// The operation here can make sure for the partially completed
intermediate stage,
// `findMissingPartitions()` returns all partitions every time.
stage match {
- case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable =>
- // already executed at least once
- if (sms.getNextAttemptId > 0) {
- // While we previously validated possible rollbacks during the
handling of a FetchFailure,
- // where we were fetching from an indeterminate source map stages,
this later check
- // covers additional cases like recalculating an indeterminate stage
after an executor
- // loss. Moreover, because this check occurs later in the process,
if a result stage task
- // has successfully completed, we can detect this and abort the job,
as rolling back a
- // result stage is not possible.
- val stagesToRollback = collectSucceedingStages(sms)
- abortStageWithInvalidRollBack(stagesToRollback)
- // stages which cannot be rolled back were aborted which leads to
removing the
- // the dependant job(s) from the active jobs set
- val numActiveJobsWithStageAfterRollback =
- activeJobs.count(job => stagesToRollback.contains(job.finalStage))
- if (numActiveJobsWithStageAfterRollback == 0) {
- logInfo(log"All jobs depending on the indeterminate stage " +
- log"(${MDC(STAGE_ID, stage.id)}) were aborted so this stage is
not needed anymore.")
- return
+ case sms: ShuffleMapStage if !sms.isAvailable =>
+ val needFullStageRetry = if
(sms.shuffleDep.checksumMismatchFullRetryEnabled) {
+ // When the parents of this stage are indeterminate (e.g., some
parents are not
+ // checkpointed and checksum mismatches are detected), the output
data of the parents
+ // may have changed due to task retries. For correctness reason, we
need to
+ // retry all tasks of the current stage. The legacy way of using
current stage's
+ // deterministic level to trigger full stage retry is not accurate.
+ stage.isParentIndeterminate
+ } else {
+ if (stage.isIndeterminate) {
+ // already executed at least once
+ if (sms.getNextAttemptId > 0) {
+ // While we previously validated possible rollbacks during the
handling of a FetchFailure,
+ // where we were fetching from an indeterminate source map
stages, this later check
+ // covers additional cases like recalculating an indeterminate
stage after an executor
+ // loss. Moreover, because this check occurs later in the
process, if a result stage task
+ // has successfully completed, we can detect this and abort the
job, as rolling back a
+ // result stage is not possible.
+ val stagesToRollback = collectSucceedingStages(sms)
+ abortStageWithInvalidRollBack(stagesToRollback)
+ // stages which cannot be rolled back were aborted which leads
to removing the
+ // the dependant job(s) from the active jobs set
+ val numActiveJobsWithStageAfterRollback =
+ activeJobs.count(job =>
stagesToRollback.contains(job.finalStage))
+ if (numActiveJobsWithStageAfterRollback == 0) {
+ logInfo(log"All jobs depending on the indeterminate stage " +
+ log"(${MDC(STAGE_ID, stage.id)}) were aborted so this stage
is not needed anymore.")
+ return
+ }
+ }
+ true
+ } else {
+ false
}
}
- mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId)
- sms.shuffleDep.newShuffleMergeState()
+
+ if (needFullStageRetry) {
+ mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId)
+ sms.shuffleDep.newShuffleMergeState()
+ }
case _ =>
}
@@ -1886,6 +1903,20 @@ private[spark] class DAGScheduler(
}
}
+ /**
+ * If a map stage is non-deterministic, the map tasks of the stage may
return different result
+ * when re-try. To make sure data correctness, we need to re-try all the
tasks of its succeeding
+ * stages, as the input data may be changed after the map tasks are
re-tried. For stages where
+ * rollback and retry all tasks are not possible, we will need to abort the
stages.
+ */
+ private[scheduler] def abortUnrollbackableStages(mapStage: ShuffleMapStage):
Unit = {
+ val stagesToRollback = collectSucceedingStages(mapStage)
+ val rollingBackStages = abortStageWithInvalidRollBack(stagesToRollback)
+ logInfo(log"The shuffle map stage ${MDC(SHUFFLE_ID, mapStage)} with
indeterminate output " +
+ log"was failed, we will roll back and rerun below stages which include
itself and all its " +
+ log"indeterminate child stages: ${MDC(STAGES, rollingBackStages)}")
+ }
+
/**
* Responds to a task finishing. This is called inside the event loop so it
assumes that it can
* modify the scheduler's internal state. Use taskEnded() to post a task end
event from outside.
@@ -2022,8 +2053,26 @@ private[spark] class DAGScheduler(
// The epoch of the task is acceptable (i.e., the task was
launched after the most
// recent failure we're aware of for the executor), so mark
the task's output as
// available.
- mapOutputTracker.registerMapOutput(
+ val isChecksumMismatched = mapOutputTracker.registerMapOutput(
shuffleStage.shuffleDep.shuffleId, smt.partitionId, status)
+ if (isChecksumMismatched) {
+ shuffleStage.isChecksumMismatched = isChecksumMismatched
+ // There could be multiple checksum mismatches detected for
a single stage attempt.
+ // We check for stage abortion once and only once when we
first detect checksum
+ // mismatch for each stage attempt. For example, assume that
we have
+ // stage1 -> stage2, and we encounter checksum mismatch
during the retry of stage1.
+ // In this case, we need to call abortUnrollbackableStages()
for the succeeding
+ // stages. Assume that when stage2 is retried, some tasks
finish and some tasks
+ // failed again with FetchFailed. In case that we encounter
checksum mismatch again
+ // during the retry of stage1, we need to call
abortUnrollbackableStages() again.
+ if (shuffleStage.maxChecksumMismatchedId <
smt.stageAttemptId) {
+ shuffleStage.maxChecksumMismatchedId = smt.stageAttemptId
+ if
(shuffleStage.shuffleDep.checksumMismatchFullRetryEnabled
+ && shuffleStage.isStageIndeterminate) {
+ abortUnrollbackableStages(shuffleStage)
+ }
+ }
+ }
}
} else {
logInfo(log"Ignoring ${MDC(TASK_NAME, smt)} completion from an
older attempt of indeterminate stage")
@@ -2148,12 +2197,8 @@ private[spark] class DAGScheduler(
// Note that, if map stage is UNORDERED, we are fine. The
shuffle partitioner is
// guaranteed to be determinate, so the input data of the
reducers will not change
// even if the map tasks are re-tried.
- if (mapStage.isIndeterminate) {
- val stagesToRollback = collectSucceedingStages(mapStage)
- val rollingBackStages =
abortStageWithInvalidRollBack(stagesToRollback)
- logInfo(log"The shuffle map stage ${MDC(SHUFFLE_ID, mapStage)}
with indeterminate output was failed, " +
- log"we will roll back and rerun below stages which include
itself and all its " +
- log"indeterminate child stages: ${MDC(STAGES,
rollingBackStages)}")
+ if (mapStage.isIndeterminate &&
!mapStage.shuffleDep.checksumMismatchFullRetryEnabled) {
+ abortUnrollbackableStages(mapStage)
}
// We expect one executor failure to trigger many FetchFailures
in rapid succession,
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index f35beafd8748..9bf604e9a83c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -72,6 +72,18 @@ private[scheduler] abstract class Stage(
private var nextAttemptId: Int = 0
private[scheduler] def getNextAttemptId: Int = nextAttemptId
+ /**
+ * Whether checksum mismatches have been detected across different attempt
of the stage, where
+ * checksum mismatches typically indicates that different stage attempts
have produced different
+ * data.
+ */
+ private[scheduler] var isChecksumMismatched: Boolean = false
+
+ /**
+ * The maximum of task attempt id where checksum mismatches are detected.
+ */
+ private[scheduler] var maxChecksumMismatchedId: Int = nextAttemptId
+
val name: String = callSite.shortForm
val details: String = callSite.longForm
@@ -131,4 +143,14 @@ private[scheduler] abstract class Stage(
def isIndeterminate: Boolean = {
rdd.outputDeterministicLevel == DeterministicLevel.INDETERMINATE
}
+
+ // Returns true if any parents of this stage are indeterminate.
+ def isParentIndeterminate: Boolean = {
+ parents.exists(_.isStageIndeterminate)
+ }
+
+ // Returns true if the stage itself is indeterminate.
+ def isStageIndeterminate: Boolean = {
+ !rdd.isReliablyCheckpointed && isChecksumMismatched
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 1ada81cbdd0e..c20866fda0a3 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -3415,6 +3415,19 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
assertDataStructuresEmpty()
}
+ private def checkAndCompleteRetryStage(
+ taskSetIndex: Int,
+ stageId: Int,
+ shuffleId: Int,
+ numTasks: Int = 2,
+ checksumVal: Long = 0): Unit = {
+ assert(taskSets(taskSetIndex).stageId == stageId)
+ assert(taskSets(taskSetIndex).stageAttemptId == 1)
+ assert(taskSets(taskSetIndex).tasks.length == numTasks)
+ completeShuffleMapStageSuccessfully(stageId, 1, 2, checksumVal =
checksumVal)
+ assert(mapOutputTracker.findMissingPartitions(shuffleId) ===
Some(Seq.empty))
+ }
+
test("SPARK-25341: continuous indeterminate stage roll back") {
// shuffleMapRdd1/2/3 are all indeterminate.
val shuffleMapRdd1 = new MyRDD(sc, 2, Nil, indeterminate = true)
@@ -3454,17 +3467,6 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
assert(scheduler.failedStages.toSeq.map(_.id) == Seq(1, 2))
scheduler.resubmitFailedStages()
- def checkAndCompleteRetryStage(
- taskSetIndex: Int,
- stageId: Int,
- shuffleId: Int): Unit = {
- assert(taskSets(taskSetIndex).stageId == stageId)
- assert(taskSets(taskSetIndex).stageAttemptId == 1)
- assert(taskSets(taskSetIndex).tasks.length == 2)
- completeShuffleMapStageSuccessfully(stageId, 1, 2)
- assert(mapOutputTracker.findMissingPartitions(shuffleId) ===
Some(Seq.empty))
- }
-
// Check all indeterminate stage roll back.
checkAndCompleteRetryStage(3, 0, shuffleId1)
checkAndCompleteRetryStage(4, 1, shuffleId2)
@@ -3477,6 +3479,253 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
assertDataStructuresEmpty()
}
+ // Construct the scenario of stages with checksum mismatches and FetchFailed.
+ private def constructChecksumMismatchStageFetchFailed(): (Int, Int) = {
+ val shuffleMapRdd1 = new MyRDD(sc, 2, Nil)
+
+ val shuffleDep1 = new ShuffleDependency(
+ shuffleMapRdd1,
+ new HashPartitioner(2),
+ checksumMismatchFullRetryEnabled = true
+ )
+ val shuffleId1 = shuffleDep1.shuffleId
+ val shuffleMapRdd2 = new MyRDD(sc, 2, List(shuffleDep1), tracker =
mapOutputTracker)
+
+ val shuffleDep2 = new ShuffleDependency(
+ shuffleMapRdd2,
+ new HashPartitioner(2),
+ checksumMismatchFullRetryEnabled = true
+ )
+ val shuffleId2 = shuffleDep2.shuffleId
+ val finalRdd = new MyRDD(sc, 2, List(shuffleDep2), tracker =
mapOutputTracker)
+
+ submit(finalRdd, Array(0, 1))
+
+ // Finish the first shuffle map stage.
+ completeShuffleMapStageSuccessfully(
+ 0, 0, 2, Seq("hostA", "hostB"), checksumVal = 100)
+ assert(mapOutputTracker.findMissingPartitions(shuffleId1) ===
Some(Seq.empty))
+
+ // The first task of the second shuffle map stage failed with FetchFailed.
+ runEvent(makeCompletionEvent(
+ taskSets(1).tasks(0),
+ FetchFailed(makeBlockManagerId("hostA"), shuffleId1, 0L, 0, 0,
"ignored"),
+ null))
+
+ // Finish the second task of the second shuffle map stage.
+ runEvent(makeCompletionEvent(
+ taskSets(1).tasks(1), Success, makeMapStatus("hostB", 2),
+ Seq.empty, Array.empty, createFakeTaskInfoWithId(1)))
+
+ (shuffleId1, shuffleId2)
+ }
+
+ // Construct the scenario of stages with checksum mismatches and FetchFailed.
+ // This function assumes that the input `mapRdd` has a single stage with 2
partitions.
+ private def constructChecksumMismatchStageFetchFailed(mapRdd: MyRDD): Unit =
{
+ val shuffleDep = new ShuffleDependency(
+ mapRdd,
+ new HashPartitioner(2),
+ checksumMismatchFullRetryEnabled = true
+ )
+ val shuffleId = shuffleDep.shuffleId
+ val finalRdd = new MyRDD(sc, 2, List(shuffleDep), tracker =
mapOutputTracker)
+
+ submit(finalRdd, Array(0, 1))
+
+ completeShuffleMapStageSuccessfully(
+ 0, 0, numShufflePartitions = 2, Seq("hostA", "hostB"), checksumVal = 100)
+ assert(mapOutputTracker.findMissingPartitions(shuffleId) ===
Some(Seq.empty))
+
+ // Fail the first task of the result stage with FetchFailed.
+ runEvent(makeCompletionEvent(
+ taskSets(1).tasks(0),
+ FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0L, 0, 0, "ignored"),
+ null))
+
+ // Finish the second task of the result stage.
+ runEvent(makeCompletionEvent(
+ taskSets(1).tasks(1), Success, 42,
+ Seq.empty, Array.empty, createFakeTaskInfoWithId(0)))
+
+ // Check status for all failedStages.
+ val failedStages = scheduler.failedStages.toSeq
+ // Shuffle blocks of "hostA" is lost, so first task of the shuffle map
stage and
+ // result stage needs to retry.
+ assert(failedStages.map(_.id) == Seq(0, 1))
+ assert(failedStages.forall(_.findMissingPartitions() == Seq(0)))
+
+ scheduler.resubmitFailedStages()
+
+ // First shuffle map stage reran failed tasks with a different checksum.
+ completeShuffleMapStageSuccessfully(0, 1, 2, checksumVal = 101)
+ }
+
+ private def assertChecksumMismatchResultStageFailToRollback(mapRdd: MyRDD):
Unit = {
+ constructChecksumMismatchStageFetchFailed(mapRdd)
+
+ // The job should fail because Spark can't rollback the result stage.
+ assert(failure != null && failure.getMessage.contains("Spark cannot
rollback"))
+ }
+
+ private def assertChecksumMismatchResultStageNotRolledBack(mapRdd: MyRDD):
Unit = {
+ constructChecksumMismatchStageFetchFailed(mapRdd)
+
+ assert(failure == null, "job should not fail")
+ // Result stage success, all job ended.
+ complete(taskSets(3), Seq((Success, 41)))
+ assert(results === Map(0 -> 41, 1 -> 42))
+ results.clear()
+ assertDataStructuresEmpty()
+ }
+
+ test("SPARK-53575: abort stage while using old fetch protocol") {
+ conf.set(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL.key, "true")
+ constructChecksumMismatchStageFetchFailed()
+
+ scheduler.resubmitFailedStages()
+ completeShuffleMapStageSuccessfully(0, 1, 2, checksumVal = 101)
+
+ // The job should fail because Spark can't rollback the shuffle map stage
while
+ // using old protocol.
+ assert(failure != null && failure.getMessage.contains(
+ "Spark can only do this while using the new shuffle block fetching
protocol"))
+ }
+
+ test("SPARK-53575: retry all the succeeding stages when the map stage has
checksum mismatches") {
+ val (shuffleId1, shuffleId2) =
+ constructChecksumMismatchStageFetchFailed()
+
+ // Check status for all failedStages.
+ val failedStages = scheduler.failedStages.toSeq
+ // Shuffle blocks of "hostA" is lost, so first task of the
`shuffleMapRdd1` and
+ // `shuffleMapRdd2` needs to retry.
+ assert(failedStages.map(_.id) == Seq(0, 1))
+ assert(failedStages.forall(_.findMissingPartitions() == Seq(0)))
+
+ scheduler.resubmitFailedStages()
+
+ // First shuffle map stage reran failed tasks with a different checksum.
+ checkAndCompleteRetryStage(2, 0, shuffleId1, numTasks = 1, checksumVal =
101)
+
+ // Second shuffle map stage reran all tasks.
+ checkAndCompleteRetryStage(3, 1, shuffleId2, numTasks = 2)
+
+ complete(taskSets(4), Seq((Success, 11), (Success, 12)))
+
+ // Job successful ended.
+ assert(results === Map(0 -> 11, 1 -> 12))
+ results.clear()
+ assertDataStructuresEmpty()
+ }
+
+ test("SPARK-53575: continuous checksum mismatch stage roll back") {
+ // shuffleMapRdd1/2 have checksum mismatches, and shuffleMapRdd2/3
requires full stage retries.
+ val shuffleMapRdd1 = new MyRDD(sc, 2, Nil)
+ val shuffleDep1 = new ShuffleDependency(
+ shuffleMapRdd1,
+ new HashPartitioner(2),
+ checksumMismatchFullRetryEnabled = true
+ )
+ val shuffleId1 = shuffleDep1.shuffleId
+
+ val shuffleMapRdd2 = new MyRDD(
+ sc, 2, List(shuffleDep1), tracker = mapOutputTracker)
+ val shuffleDep2 = new ShuffleDependency(
+ shuffleMapRdd2,
+ new HashPartitioner(2),
+ checksumMismatchFullRetryEnabled = true
+ )
+ val shuffleId2 = shuffleDep2.shuffleId
+
+ val shuffleMapRdd3 = new MyRDD(
+ sc, 2, List(shuffleDep2), tracker = mapOutputTracker)
+ val shuffleDep3 = new ShuffleDependency(
+ shuffleMapRdd3,
+ new HashPartitioner(2),
+ checksumMismatchFullRetryEnabled = true
+ )
+ val shuffleId3 = shuffleDep3.shuffleId
+ val finalRdd = new MyRDD(sc, 2, List(shuffleDep3), tracker =
mapOutputTracker)
+
+ submit(finalRdd, Array(0, 1), properties = new Properties())
+
+ // Finish the first 2 shuffle map stages.
+ completeShuffleMapStageSuccessfully(0, 0, 2, Seq("hostA", "hostB"),
checksumVal = 100)
+ assert(mapOutputTracker.findMissingPartitions(shuffleId1) ===
Some(Seq.empty))
+ completeShuffleMapStageSuccessfully(1, 0, 2, Seq("hostA", "hostB"),
checksumVal = 200)
+ assert(mapOutputTracker.findMissingPartitions(shuffleId2) ===
Some(Seq.empty))
+
+ // Fail the first task of the third shuffle map stage with FetchFailed.
+ runEvent(makeCompletionEvent(
+ taskSets(2).tasks(0),
+ FetchFailed(makeBlockManagerId("hostA"), shuffleId2, 0L, 0, 0,
"ignored"),
+ null))
+
+ // Finish the second task of the third shuffle map stage.
+ runEvent(makeCompletionEvent(
+ taskSets(2).tasks(1), Success, makeMapStatus("hostB", 2),
+ Seq.empty, Array.empty, createFakeTaskInfoWithId(1)))
+ mapOutputTracker.removeOutputsOnHost("hostA")
+
+ // Check status for all failedStages.
+ val failedStages = scheduler.failedStages.toSeq
+ // Shuffle blocks of "hostA" is lost, so first task of the
`shuffleMapRdd2` and
+ // `shuffleMapRdd3` needs to retry.
+ assert(failedStages.map(_.id) == Seq(1, 2))
+ assert(failedStages.forall(_.findMissingPartitions() == Seq(0)))
+
+ scheduler.resubmitFailedStages()
+
+ // First shuffle map stage reran failed tasks with a different checksum.
+ checkAndCompleteRetryStage(3, 0, shuffleId1, numTasks = 1, checksumVal =
101)
+ // Second and third shuffle map stages reran all tasks with a different
checksum.
+ checkAndCompleteRetryStage(4, 1, shuffleId2, numTasks = 2, checksumVal =
201)
+ checkAndCompleteRetryStage(5, 2, shuffleId3, numTasks = 2, checksumVal =
301)
+ // Result stage success, all job ended.
+ complete(taskSets(6), Seq((Success, 11), (Success, 12)))
+ assert(results === Map(0 -> 11, 1 -> 12))
+ results.clear()
+ assertDataStructuresEmpty()
+ }
+
+ test("SPARK-53575: cannot rollback a result stage") {
+ val shuffleMapRdd = new MyRDD(sc, 2, Nil)
+ assertChecksumMismatchResultStageFailToRollback(shuffleMapRdd)
+ }
+
+ test("SPARK-53575: local checkpoint fail to rollback (checkpointed before)")
{
+ val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil)
+ shuffleMapRdd.localCheckpoint()
+ shuffleMapRdd.doCheckpoint()
+ assertChecksumMismatchResultStageFailToRollback(shuffleMapRdd)
+ }
+
+ test("SPARK-53575: local checkpoint fail to rollback (checkpointing now)") {
+ val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil)
+ shuffleMapRdd.localCheckpoint()
+ assertChecksumMismatchResultStageFailToRollback(shuffleMapRdd)
+ }
+
+ test("SPARK-53575: reliable checkpoint can avoid rollback (checkpointed
before)") {
+ withTempDir { dir =>
+ sc.setCheckpointDir(dir.getCanonicalPath)
+ val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil)
+ shuffleMapRdd.checkpoint()
+ shuffleMapRdd.doCheckpoint()
+ assertChecksumMismatchResultStageNotRolledBack(shuffleMapRdd)
+ }
+ }
+
+ test("SPARK-53575: reliable checkpoint fail to rollback (checkpointing
now)") {
+ withTempDir { dir =>
+ sc.setCheckpointDir(dir.getCanonicalPath)
+ val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil)
+ shuffleMapRdd.checkpoint()
+ assertChecksumMismatchResultStageFailToRollback(shuffleMapRdd)
+ }
+ }
+
test("SPARK-29042: Sampled RDD with unordered input should be
indeterminate") {
val shuffleMapRdd1 = new MyRDD(sc, 2, Nil, indeterminate = false)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 17b8dd493cf8..477d09d29a05 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -890,6 +890,14 @@ object SQLConf {
.booleanConf
.createWithDefault(false)
+ private[spark] val SHUFFLE_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED =
+
buildConf("spark.sql.shuffle.orderIndependentChecksum.enableFullRetryOnMismatch")
+ .doc("Whether to retry all tasks of a consumer stage when we detect
checksum mismatches " +
+ "with its producer stages.")
+ .version("4.1.0")
+ .booleanConf
+ .createWithDefault(false)
+
val SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE =
buildConf("spark.sql.adaptive.shuffle.targetPostShuffleInputSize")
.internal()
@@ -6651,6 +6659,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def shuffleOrderIndependentChecksumEnabled: Boolean =
getConf(SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED)
+ def shuffleChecksumMismatchFullRetryEnabled: Boolean =
+ getConf(SHUFFLE_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED)
+
def allowCollationsInMapKeys: Boolean = getConf(ALLOW_COLLATIONS_IN_MAP_KEYS)
def objectLevelCollationsEnabled: Boolean =
getConf(OBJECT_LEVEL_COLLATIONS_ENABLED)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index 9c86bbb606a5..f052bd906880 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -480,19 +480,22 @@ object ShuffleExchangeExec {
// Now, we manually create a ShuffleDependency. Because pairs in
rddWithPartitionIds
// are in the form of (partitionId, row) and every partitionId is in the
expected range
// [0, part.numPartitions - 1]. The partitioner of this is a
PartitionIdPassthrough.
- val checksumSize =
- if (SQLConf.get.shuffleOrderIndependentChecksumEnabled) {
+ val checksumSize = {
+ if (SQLConf.get.shuffleOrderIndependentChecksumEnabled ||
+ SQLConf.get.shuffleChecksumMismatchFullRetryEnabled) {
part.numPartitions
} else {
0
}
+ }
val dependency =
new ShuffleDependency[Int, InternalRow, InternalRow](
rddWithPartitionIds,
new PartitionIdPassthrough(part.numPartitions),
serializer,
shuffleWriterProcessor = createShuffleWriteProcessor(writeMetrics),
- rowBasedChecksums =
UnsafeRowChecksum.createUnsafeRowChecksums(checksumSize))
+ rowBasedChecksums =
UnsafeRowChecksum.createUnsafeRowChecksums(checksumSize),
+ checksumMismatchFullRetryEnabled =
SQLConf.get.shuffleChecksumMismatchFullRetryEnabled)
dependency
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala
index 0fe660312210..abcd346c3277 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala
@@ -25,7 +25,6 @@ import org.apache.spark.sql.test.SQLTestUtils
class MapStatusEndToEndSuite extends SparkFunSuite with SQLTestUtils {
override def spark: SparkSession = SparkSession.builder()
.master("local")
- .config(SQLConf.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED.key, value =
true)
.config(SQLConf.LEAF_NODE_DEFAULT_PARALLELISM.key, value = 5)
.config(SQLConf.CLASSIC_SHUFFLE_DEPENDENCY_FILE_CLEANUP_ENABLED.key,
value = false)
.getOrCreate()
@@ -39,26 +38,34 @@ class MapStatusEndToEndSuite extends SparkFunSuite with SQLTestUtils {
}
test("Propagate checksum from executor to driver") {
- assert(spark.sparkContext.conf
- .get("spark.sql.shuffle.orderIndependentChecksum.enabled") == "true")
-
assert(spark.conf.get("spark.sql.shuffle.orderIndependentChecksum.enabled") ==
"true")
- assert(spark.sparkContext.conf.get("spark.sql.leafNodeDefaultParallelism")
== "5")
- assert(spark.conf.get("spark.sql.leafNodeDefaultParallelism") == "5")
-
assert(spark.sparkContext.conf.get("spark.sql.classic.shuffleDependency.fileCleanup.enabled")
+
assert(spark.sparkContext.conf.get(SQLConf.LEAF_NODE_DEFAULT_PARALLELISM.key)
== "5")
+ assert(spark.conf.get(SQLConf.LEAF_NODE_DEFAULT_PARALLELISM.key) == "5")
+
assert(spark.sparkContext.conf.get(SQLConf.CLASSIC_SHUFFLE_DEPENDENCY_FILE_CLEANUP_ENABLED.key)
== "false")
-
assert(spark.conf.get("spark.sql.classic.shuffleDependency.fileCleanup.enabled")
== "false")
+
assert(spark.conf.get(SQLConf.CLASSIC_SHUFFLE_DEPENDENCY_FILE_CLEANUP_ENABLED.key)
== "false")
- withTable("t") {
- spark.range(1000).repartition(10).write.mode("overwrite").
- saveAsTable("t")
- }
+ var shuffleId = 0
+ Seq(("true", "false"), ("false", "true"), ("true", "true")).foreach {
+ case (orderIndependentChecksumEnabled: String,
checksumMismatchFullRetryEnabled: String) =>
+ withSQLConf(
+ SQLConf.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED.key ->
+ orderIndependentChecksumEnabled,
+ SQLConf.SHUFFLE_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED.key ->
+ checksumMismatchFullRetryEnabled) {
+ withTable("t") {
+ spark.range(1000).repartition(10).write.mode("overwrite").
+ saveAsTable("t")
+ }
- val shuffleStatuses = spark.sparkContext.env.mapOutputTracker.
- asInstanceOf[MapOutputTrackerMaster].shuffleStatuses
- assert(shuffleStatuses.size == 1)
+ val shuffleStatuses = spark.sparkContext.env.mapOutputTracker.
+ asInstanceOf[MapOutputTrackerMaster].shuffleStatuses
+ assert(shuffleStatuses.contains(shuffleId))
- val mapStatuses = shuffleStatuses(0).mapStatuses
- assert(mapStatuses.length == 5)
- assert(mapStatuses.forall(_.checksumValue != 0))
+ val mapStatuses = shuffleStatuses(shuffleId).mapStatuses
+ assert(mapStatuses.length == 5)
+ assert(mapStatuses.forall(_.checksumValue != 0))
+ shuffleId += 1
+ }
+ }
}
}