This is an automated email from the ASF dual-hosted git repository.
mridulm80 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 9ab693c01ed8 [SPARK-51756][CORE] Computes RowBasedChecksum in ShuffleWriters
9ab693c01ed8 is described below
commit 9ab693c01ed85fdb7a41f0a1f6403335e148f2d6
Author: Tengfei Huang <[email protected]>
AuthorDate: Tue Sep 9 19:34:09 2025 -0500
[SPARK-51756][CORE] Computes RowBasedChecksum in ShuffleWriters
### What changes were proposed in this pull request?
This PR computes RowBasedChecksum in ShuffleWriters, controlled by spark.shuffle.rowbased.checksum.enabled.
If enabled, Spark calculates a RowBasedChecksum value for each partition of each map output and returns the values from the executors to the driver. Unlike the existing shuffle checksum, RowBasedChecksum is independent of the input row order, and it is used to detect whether different task attempts of the same partition produce different output data (key or value). If the output data has changed across retries, Spark will need to retry all tasks of the consumer [...]
This PR contains only the RowBasedChecksum computation. In the next PR, I plan to trigger the full stage retry when checksum mismatches are detected.
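For illustration only, below is a minimal, self-contained sketch (not the code from this patch) of how an order-independent checksum can be built by XOR-ing and summing per-row checksums and then mixing the two accumulators, in the spirit of RowBasedChecksum. The `hashRow` function is a hypothetical stand-in for a real per-row checksum such as UnsafeRowChecksum.

```scala
// Minimal sketch (not the patch code): an order-independent checksum that
// XORs and SUMs per-row checksums, then mixes the two accumulators. The
// `hashRow` function is a hypothetical stand-in for a real per-row checksum.
object OrderIndependentChecksumSketch {
  private val RotatePositions = 27

  private def rotateLeft(value: Long): Long =
    (value << RotatePositions) | (value >>> (64 - RotatePositions))

  // Stand-in per-row checksum over the (key, value) pair.
  private def hashRow(key: Any, value: Any): Long = (key, value).hashCode().toLong

  def checksum(rows: Seq[(Any, Any)]): Long = {
    var xor = 0L
    var sum = 0L
    rows.foreach { case (k, v) =>
      val rowChecksum = hashRow(k, v)
      xor ^= rowChecksum // XOR of row checksums is order independent
      sum += rowChecksum // so is the SUM
    }
    xor ^ rotateLeft(sum) // mix the two accumulators into a single value
  }

  def main(args: Array[String]): Unit = {
    val rows = Seq[(Any, Any)](1 -> "a", 2 -> "b", 3 -> "c")
    // Reordering the rows does not change the checksum.
    assert(checksum(rows) == checksum(rows.reverse))
    // A different row set yields a different checksum (with high probability).
    assert(checksum(rows) != checksum(rows :+ (4 -> "d")))
  }
}
```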
### Why are the changes needed?
Problem:
Spark's resilience features can cause an RDD to be partially recomputed, e.g. when an executor is lost due to downscaling or a spot instance kill. When the output of a non-deterministic task is recomputed, Spark does not always recompute everything that depends on that task's output. In some cases, some subsequent computations are based on the output of one "attempt" of the task, while other subsequent computations are based on another "attempt".
This can be problematic when the producer stage is non-deterministic: the second attempt of the same task can produce output that is very different from the first one. For example, if the stage uses round-robin partitioning, some of the output data could be placed in different partitions in different task attempts. This can lead to incorrect results unless we retry the whole consumer stage that depends on the retried non-deterministic stage.
Below is an example of this.
Example:
Let’s say we have Stage 1 and Stage 2, where Stage 1 is the producer and Stage 2 is the consumer. Assume that the data produced by Task 2 of Stage 1 is lost for some reason while Stage 2 is executing. Further assume that at this point, Task 1 of Stage 2 has already fetched all its inputs and finished, while Task 2 of Stage 2 fails with data fetch failures.
<img width="600" alt="example 1"
src="https://github.com/user-attachments/assets/549d1d90-3a8c-43e3-a891-1a6c614e9f24"
/>
Task 2 of Stage 1 will be retried to reproduce the data, after which Task 2 of Stage 2 is retried. Eventually, Task 1 and Task 2 of Stage 2 produce the result, which contains all 4 tuples {t1, t2, t3, t4}, as shown in the example graph.
<img width="720" alt="example 2"
src="https://github.com/user-attachments/assets/bebf03d5-f05e-46b6-8f78-bfad08999867"
/>
Now, let’s assume that Stage 1 is non-deterministic (e.g., it uses round-robin partitioning and the input data is not ordered), and that Task 2 places tuple t3 in Partition 1 and tuple t4 in Partition 2 in its first attempt, but places tuple t4 in Partition 1 and tuple t3 in Partition 2 in its second attempt. When Task 2 of Stage 2 is retried, instead of reading {t2, t4} as it should, it reads {t2, t3} as its input. The result generated by Stage 2 is then {t1, t2, t3, t3}, which is incorrect.
<img width="720" alt="example 3"
src="https://github.com/user-attachments/assets/730fac0f-dfc3-4392-a74f-ed3e0d11e665"
/>
The problem can be avoided if we retry all tasks of Stage 2. Since all tasks then read consistent data, the result is produced correctly, regardless of how the retried Task 2 of Stage 1 partitions the data.
<img width="720" alt="example 4"
src="https://github.com/user-attachments/assets/a501a33e-97bb-4a01-954f-bc7d0f01f3e6"
/>
Proposal:
To avoid correctness issues produced by partially retrying a non-deterministic stage, we propose an approach that first detects inconsistent data that might have been generated by different task attempts of such a stage, for example, whether the data partitions generated by Task 2 in the first attempt are the same as those generated by the second attempt. If inconsistent data is detected, we retry the entire consumer stages.
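As a rough illustration of the detection step (a hypothetical helper, not the ShuffleStatus code in this patch), the driver only needs to remember the last reported checksum per map index and flag any index whose checksum changes across attempts:

```scala
// Hypothetical sketch of the driver-side detection: remember the checksum last
// reported for each map index and flag any index whose checksum changes when a
// new attempt of the same map task registers its output.
import scala.collection.mutable

final class ChecksumMismatchDetector {
  private val lastChecksum = mutable.Map.empty[Int, Long] // mapIndex -> checksum
  val mismatchedMapIndices = mutable.Set.empty[Int]

  def registerMapOutput(mapIndex: Int, checksumValue: Long): Unit = {
    lastChecksum.get(mapIndex) match {
      case Some(previous) if previous != checksumValue =>
        // Two attempts of the same map task produced different data; a follow-up
        // change can use this signal to retry the whole consumer stage.
        mismatchedMapIndices += mapIndex
      case _ => // first registration, or same checksum as before
    }
    lastChecksum(mapIndex) = checksumValue
  }
}
```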
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Unit tests.
Benchmark test:
TPC-DS (10 GB): the overhead of checksum computation with UnsafeRowChecksum is 0.4%.
TPC-DS (3 TB): the overhead of checksum computation with UnsafeRowChecksum is 0.72%.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #50230 from JiexingLi/shuffle-checksum.
Lead-authored-by: Tengfei Huang <[email protected]>
Co-authored-by: Jiexing Li <[email protected]>
Co-authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
---
.../spark/shuffle/checksum/RowBasedChecksum.scala | 80 +++++++++++
.../shuffle/sort/BypassMergeSortShuffleWriter.java | 31 ++++-
.../spark/shuffle/sort/UnsafeShuffleWriter.java | 36 +++--
.../util/ExposedBufferByteArrayOutputStream.java} | 26 +---
.../main/scala/org/apache/spark/Dependency.scala | 28 +++-
.../scala/org/apache/spark/MapOutputTracker.scala | 17 ++-
.../org/apache/spark/scheduler/MapStatus.scala | 46 +++++--
.../spark/shuffle/sort/SortShuffleWriter.scala | 22 ++-
.../spark/util/collection/ExternalSorter.scala | 23 +++-
.../shuffle/sort/UnsafeShuffleWriterSuite.java | 49 ++++++-
.../org/apache/spark/MapOutputTrackerSuite.scala | 6 +-
.../spark/MapStatusesSerDeserBenchmark.scala | 2 +-
.../apache/spark/scheduler/DAGSchedulerSuite.scala | 52 ++++++-
.../spark/shuffle/ShuffleChecksumTestHelper.scala | 14 ++
.../checksum/OutputStreamRowBasedChecksum.scala | 64 +++++++++
.../sort/BypassMergeSortShuffleWriterSuite.scala | 58 +++++++-
.../shuffle/sort/SortShuffleWriterSuite.scala | 63 ++++++++-
.../catalyst/expressions/UnsafeRowChecksum.scala | 53 ++++++++
.../org/apache/spark/sql/internal/SQLConf.scala | 18 +++
.../execution/exchange/ShuffleExchangeExec.scala | 11 +-
.../apache/spark/sql/MapStatusEndToEndSuite.scala | 64 +++++++++
.../apache/spark/sql/UnsafeRowChecksumSuite.scala | 149 +++++++++++++++++++++
22 files changed, 841 insertions(+), 71 deletions(-)
diff --git
a/core/src/main/java/org/apache/spark/shuffle/checksum/RowBasedChecksum.scala
b/core/src/main/java/org/apache/spark/shuffle/checksum/RowBasedChecksum.scala
new file mode 100644
index 000000000000..886296dc8a82
--- /dev/null
+++
b/core/src/main/java/org/apache/spark/shuffle/checksum/RowBasedChecksum.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.checksum
+
+import scala.util.control.NonFatal
+
+import org.apache.spark.internal.Logging
+
+/**
+ * A class for computing checksum for input (key, value) pairs. The checksum is independent of
+ * the order of the input (key, value) pairs. It is done by computing a checksum for each row
+ * first, then computing the XOR and SUM for all the row checksums and mixing these two values
+ * as the final checksum.
+ */
+abstract class RowBasedChecksum() extends Serializable with Logging {
+ private val ROTATE_POSITIONS = 27
+ private var hasError: Boolean = false
+ private var checksumXor: Long = 0
+ private var checksumSum: Long = 0
+
+ /**
+ * Returns the checksum value. It returns the default checksum value (0) if there
+ * are any errors encountered during the checksum computation.
+ */
+ def getValue: Long = {
+ if (!hasError) {
+ // Here we rotate the `checksumSum` to transform these two values into a single, strong
+ // composite checksum by ensuring their bit patterns are thoroughly mixed.
+ checksumXor ^ rotateLeft(checksumSum)
+ } else {
+ 0
+ }
+ }
+
+ /** Updates the row-based checksum with the given (key, value) pair. Not thread safe. */
+ def update(key: Any, value: Any): Unit = {
+ if (!hasError) {
+ try {
+ val rowChecksumValue = calculateRowChecksum(key, value)
+ checksumXor = checksumXor ^ rowChecksumValue
+ checksumSum += rowChecksumValue
+ } catch {
+ case NonFatal(e) =>
+ logError("Checksum computation encountered error: ", e)
+ hasError = true
+ }
+ }
+ }
+
+ /** Computes and returns the checksum value for the given (key, value) pair */
+ protected def calculateRowChecksum(key: Any, value: Any): Long
+
+ // Rotate the value by shifting the bits by `ROTATE_POSITIONS` positions to the left.
+ private def rotateLeft(value: Long): Long = {
+ (value << ROTATE_POSITIONS) | (value >>> (64 - ROTATE_POSITIONS))
+ }
+}
+
+object RowBasedChecksum {
+ def getAggregatedChecksumValue(rowBasedChecksums: Array[RowBasedChecksum]): Long = {
+ Option(rowBasedChecksums)
+ .map(_.foldLeft(0L)((acc, c) => acc * 31L + c.getValue))
+ .getOrElse(0L)
+ }
+}
diff --git
a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index 8072a432ab11..5acc66e12063 100644
---
a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++
b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -32,6 +32,7 @@ import scala.Product2;
import scala.Tuple2;
import scala.collection.Iterator;
+import com.google.common.annotations.VisibleForTesting;
import com.google.common.io.Closeables;
import org.apache.spark.internal.SparkLogger;
@@ -53,6 +54,7 @@ import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.checksum.RowBasedChecksum;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.storage.*;
@@ -104,6 +106,13 @@ final class BypassMergeSortShuffleWriter<K, V>
private long[] partitionLengths;
/** Checksum calculator for each partition. Empty when shuffle checksum
disabled. */
private final Checksum[] partitionChecksums;
+ /**
+ * Checksum calculator for each partition. Different from the above Checksum,
+ * RowBasedChecksum is independent of the input row order, which is used to
+ * detect whether different task attempts of the same partition produce different
+ * output data or not.
+ */
+ private final RowBasedChecksum[] rowBasedChecksums;
/**
* Are we in the process of stopping? Because map tasks can call stop() with
success = true
@@ -132,6 +141,7 @@ final class BypassMergeSortShuffleWriter<K, V>
this.serializer = dep.serializer();
this.shuffleExecutorComponents = shuffleExecutorComponents;
this.partitionChecksums = createPartitionChecksums(numPartitions, conf);
+ this.rowBasedChecksums = dep.rowBasedChecksums();
}
@Override
@@ -144,7 +154,7 @@ final class BypassMergeSortShuffleWriter<K, V>
partitionLengths = mapOutputWriter.commitAllPartitions(
ShuffleChecksumHelper.EMPTY_CHECKSUM_VALUE).getPartitionLengths();
mapStatus = MapStatus$.MODULE$.apply(
- blockManager.shuffleServerId(), partitionLengths, mapId);
+ blockManager.shuffleServerId(), partitionLengths, mapId,
getAggregatedChecksumValue());
return;
}
final SerializerInstance serInstance = serializer.newInstance();
@@ -171,7 +181,11 @@ final class BypassMergeSortShuffleWriter<K, V>
while (records.hasNext()) {
final Product2<K, V> record = records.next();
final K key = record._1();
- partitionWriters[partitioner.getPartition(key)].write(key,
record._2());
+ final int partitionId = partitioner.getPartition(key);
+ partitionWriters[partitionId].write(key, record._2());
+ if (rowBasedChecksums.length > 0) {
+ rowBasedChecksums[partitionId].update(key, record._2());
+ }
}
for (int i = 0; i < numPartitions; i++) {
@@ -182,7 +196,7 @@ final class BypassMergeSortShuffleWriter<K, V>
partitionLengths = writePartitionedData(mapOutputWriter);
mapStatus = MapStatus$.MODULE$.apply(
- blockManager.shuffleServerId(), partitionLengths, mapId);
+ blockManager.shuffleServerId(), partitionLengths, mapId,
getAggregatedChecksumValue());
} catch (Exception e) {
try {
mapOutputWriter.abort(e);
@@ -199,6 +213,17 @@ final class BypassMergeSortShuffleWriter<K, V>
return partitionLengths;
}
+ // For test only.
+ @VisibleForTesting
+ RowBasedChecksum[] getRowBasedChecksums() {
+ return rowBasedChecksums;
+ }
+
+ @VisibleForTesting
+ long getAggregatedChecksumValue() {
+ return RowBasedChecksum.getAggregatedChecksumValue(rowBasedChecksums);
+ }
+
/**
* Concatenate all of the per-partition files into a single combined file.
*
diff --git
a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index 36a148762736..e3ecfed32348 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -59,9 +59,11 @@ import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
import org.apache.spark.shuffle.api.ShufflePartitionWriter;
import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter;
import org.apache.spark.shuffle.api.WritableByteChannelWrapper;
+import org.apache.spark.shuffle.checksum.RowBasedChecksum;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.TimeTrackingOutputStream;
import org.apache.spark.unsafe.Platform;
+import org.apache.spark.util.ExposedBufferByteArrayOutputStream;
import org.apache.spark.util.Utils;
@Private
@@ -93,15 +95,16 @@ public class UnsafeShuffleWriter<K, V> extends
ShuffleWriter<K, V> {
@Nullable private long[] partitionLengths;
private long peakMemoryUsedBytes = 0;
- /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */
- private static final class MyByteArrayOutputStream extends
ByteArrayOutputStream {
- MyByteArrayOutputStream(int size) { super(size); }
- public byte[] getBuf() { return buf; }
- }
-
- private MyByteArrayOutputStream serBuffer;
+ private ExposedBufferByteArrayOutputStream serBuffer;
private SerializationStream serOutputStream;
+ /**
+ * RowBasedChecksum calculator for each partition. RowBasedChecksum is independent
+ * of the input row order, which is used to detect whether different task attempts
+ * of the same partition produce different output data or not.
+ */
+ private final RowBasedChecksum[] rowBasedChecksums;
+
/**
* Are we in the process of stopping? Because map tasks can call stop() with
success = true
* and then call stop() with success = false if they get an exception, we
want to make sure
@@ -141,6 +144,7 @@ public class UnsafeShuffleWriter<K, V> extends
ShuffleWriter<K, V> {
(int) (long)
sparkConf.get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE());
this.mergeBufferSizeInBytes =
(int) (long)
sparkConf.get(package$.MODULE$.SHUFFLE_FILE_MERGE_BUFFER_SIZE()) * 1024;
+ this.rowBasedChecksums = dep.rowBasedChecksums();
open();
}
@@ -162,6 +166,17 @@ public class UnsafeShuffleWriter<K, V> extends
ShuffleWriter<K, V> {
return peakMemoryUsedBytes;
}
+ // For test only.
+ @VisibleForTesting
+ RowBasedChecksum[] getRowBasedChecksums() {
+ return rowBasedChecksums;
+ }
+
+ @VisibleForTesting
+ long getAggregatedChecksumValue() {
+ return RowBasedChecksum.getAggregatedChecksumValue(rowBasedChecksums);
+ }
+
/**
* This convenience method should only be called in test code.
*/
@@ -210,7 +225,7 @@ public class UnsafeShuffleWriter<K, V> extends
ShuffleWriter<K, V> {
partitioner.numPartitions(),
sparkConf,
writeMetrics);
- serBuffer = new MyByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE);
+ serBuffer = new
ExposedBufferByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE);
serOutputStream = serializer.serializeStream(serBuffer);
}
@@ -233,7 +248,7 @@ public class UnsafeShuffleWriter<K, V> extends
ShuffleWriter<K, V> {
}
}
mapStatus = MapStatus$.MODULE$.apply(
- blockManager.shuffleServerId(), partitionLengths, mapId);
+ blockManager.shuffleServerId(), partitionLengths, mapId,
getAggregatedChecksumValue());
}
@VisibleForTesting
@@ -251,6 +266,9 @@ public class UnsafeShuffleWriter<K, V> extends
ShuffleWriter<K, V> {
sorter.insertRecord(
serBuffer.getBuf(), Platform.BYTE_ARRAY_OFFSET, serializedRecordSize,
partitionId);
+ if (rowBasedChecksums.length > 0) {
+ rowBasedChecksums[partitionId].update(key, record._2());
+ }
}
@VisibleForTesting
diff --git
a/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala
b/core/src/main/java/org/apache/spark/util/ExposedBufferByteArrayOutputStream.java
similarity index 54%
copy from
core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala
copy to
core/src/main/java/org/apache/spark/util/ExposedBufferByteArrayOutputStream.java
index 8be103b7be86..bd59bd176fb9 100644
---
a/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala
+++
b/core/src/main/java/org/apache/spark/util/ExposedBufferByteArrayOutputStream.java
@@ -15,26 +15,12 @@
* limitations under the License.
*/
-package org.apache.spark.shuffle
+package org.apache.spark.util;
-import java.io.File
+import java.io.ByteArrayOutputStream;
-trait ShuffleChecksumTestHelper {
-
- /**
- * Ensure that the checksum values are consistent between write and read
side.
- */
- def compareChecksums(
- numPartition: Int,
- algorithm: String,
- checksum: File,
- data: File,
- index: File): Unit = {
- assert(checksum.exists(), "Checksum file doesn't exist")
- assert(data.exists(), "Data file doesn't exist")
- assert(index.exists(), "Index file doesn't exist")
-
- assert(ShuffleChecksumUtils.compareChecksums(numPartition, algorithm,
checksum, data, index),
- "checksum must be consistent at both write and read sides")
- }
+/** Subclass of ByteArrayOutputStream that exposes `buf` directly. */
+public final class ExposedBufferByteArrayOutputStream extends
ByteArrayOutputStream {
+ public ExposedBufferByteArrayOutputStream(int size) { super(size); }
+ public byte[] getBuf() { return buf; }
}
diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala
b/core/src/main/scala/org/apache/spark/Dependency.scala
index 745faf866ceb..93a2bbe25157 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -29,6 +29,7 @@ import org.apache.spark.internal.LogKeys._
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{ShuffleHandle, ShuffleWriteProcessor}
+import org.apache.spark.shuffle.checksum.RowBasedChecksum
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.Utils
@@ -59,6 +60,9 @@ abstract class NarrowDependency[T](_rdd: RDD[T]) extends
Dependency[T] {
override def rdd: RDD[T] = _rdd
}
+object ShuffleDependency {
+ private[spark] val EMPTY_ROW_BASED_CHECKSUMS: Array[RowBasedChecksum] =
Array.empty
+}
/**
* :: DeveloperApi ::
@@ -74,6 +78,7 @@ abstract class NarrowDependency[T](_rdd: RDD[T]) extends
Dependency[T] {
* @param aggregator map/reduce-side aggregator for RDD's shuffle
* @param mapSideCombine whether to perform partial aggregation (also known as
map-side combine)
* @param shuffleWriterProcessor the processor to control the write behavior
in ShuffleMapTask
+ * @param rowBasedChecksums the row-based checksums for each shuffle partition
*/
@DeveloperApi
class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
@@ -83,9 +88,30 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C:
ClassTag](
val keyOrdering: Option[Ordering[K]] = None,
val aggregator: Option[Aggregator[K, V, C]] = None,
val mapSideCombine: Boolean = false,
- val shuffleWriterProcessor: ShuffleWriteProcessor = new
ShuffleWriteProcessor)
+ val shuffleWriterProcessor: ShuffleWriteProcessor = new
ShuffleWriteProcessor,
+ val rowBasedChecksums: Array[RowBasedChecksum] =
ShuffleDependency.EMPTY_ROW_BASED_CHECKSUMS)
extends Dependency[Product2[K, V]] with Logging {
+ def this(
+ rdd: RDD[_ <: Product2[K, V]],
+ partitioner: Partitioner,
+ serializer: Serializer,
+ keyOrdering: Option[Ordering[K]],
+ aggregator: Option[Aggregator[K, V, C]],
+ mapSideCombine: Boolean,
+ shuffleWriterProcessor: ShuffleWriteProcessor) = {
+ this(
+ rdd,
+ partitioner,
+ serializer,
+ keyOrdering,
+ aggregator,
+ mapSideCombine,
+ shuffleWriterProcessor,
+ ShuffleDependency.EMPTY_ROW_BASED_CHECKSUMS
+ )
+ }
+
if (mapSideCombine) {
require(aggregator.isDefined, "Map-side combine without Aggregator
specified!")
}
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 9b2d3d748ed4..3f823b60156a 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -23,7 +23,7 @@ import java.util.concurrent.{ConcurrentHashMap,
LinkedBlockingQueue, ThreadPoolE
import java.util.concurrent.locks.ReentrantReadWriteLock
import scala.collection
-import scala.collection.mutable.{HashMap, ListBuffer, Map}
+import scala.collection.mutable.{HashMap, ListBuffer, Map, Set}
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.Duration
import scala.jdk.CollectionConverters._
@@ -99,6 +99,12 @@ private class ShuffleStatus(
*/
val mapStatusesDeleted = new Array[MapStatus](numPartitions)
+ /**
+ * Keep the indices of the Map tasks whose checksums are different across
retries.
+ * Exposed for testing.
+ */
+ private[spark] val checksumMismatchIndices: Set[Int] = Set()
+
/**
* MergeStatus for each shuffle partition when push-based shuffle is
enabled. The index of the
* array is the shuffle partition id (reduce id). Each value in the array is
the MergeStatus for
@@ -169,6 +175,15 @@ private class ShuffleStatus(
} else {
mapIdToMapIndex.remove(currentMapStatus.mapId)
}
+ logDebug(s"Checksum of map output for task ${status.mapId} is
${status.checksumValue}")
+
+ val preStatus =
+ if (mapStatuses(mapIndex) != null) mapStatuses(mapIndex) else
mapStatusesDeleted(mapIndex)
+ if (preStatus != null && preStatus.checksumValue != status.checksumValue) {
+ logInfo(s"Checksum of map output changes from ${preStatus.checksumValue}
to " +
+ s"${status.checksumValue} for task ${status.mapId}.")
+ checksumMismatchIndices.add(mapIndex)
+ }
mapStatuses(mapIndex) = status
mapIdToMapIndex(status.mapId) = mapIndex
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
index 113521453ad7..e348b6a5f149 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
@@ -58,6 +58,12 @@ private[spark] sealed trait MapStatus extends
ShuffleOutputStatus {
* partitionId of the task or taskContext.taskAttemptId is used.
*/
def mapId: Long
+
+ /**
+ * The checksum value of this shuffle map task, which can be used to
evaluate whether the
+ * output data has changed across different map task retries.
+ */
+ def checksumValue: Long = 0
}
@@ -74,11 +80,12 @@ private[spark] object MapStatus {
def apply(
loc: BlockManagerId,
uncompressedSizes: Array[Long],
- mapTaskId: Long): MapStatus = {
+ mapTaskId: Long,
+ checksumVal: Long = 0): MapStatus = {
if (uncompressedSizes.length > minPartitionsToUseHighlyCompressMapStatus) {
- HighlyCompressedMapStatus(loc, uncompressedSizes, mapTaskId)
+ HighlyCompressedMapStatus(loc, uncompressedSizes, mapTaskId, checksumVal)
} else {
- new CompressedMapStatus(loc, uncompressedSizes, mapTaskId)
+ new CompressedMapStatus(loc, uncompressedSizes, mapTaskId, checksumVal)
}
}
@@ -119,18 +126,24 @@ private[spark] object MapStatus {
* @param loc location where the task is being executed.
* @param compressedSizes size of the blocks, indexed by reduce partition id.
* @param _mapTaskId unique task id for the task
+ * @param _checksumVal the checksum value for the task
*/
private[spark] class CompressedMapStatus(
private[this] var loc: BlockManagerId,
private[this] var compressedSizes: Array[Byte],
- private[this] var _mapTaskId: Long)
+ private[this] var _mapTaskId: Long,
+ private[this] var _checksumVal: Long = 0)
extends MapStatus with Externalizable {
// For deserialization only
- protected def this() = this(null, null.asInstanceOf[Array[Byte]], -1)
+ protected def this() = this(null, null.asInstanceOf[Array[Byte]], -1, 0)
- def this(loc: BlockManagerId, uncompressedSizes: Array[Long], mapTaskId:
Long) = {
- this(loc, uncompressedSizes.map(MapStatus.compressSize), mapTaskId)
+ def this(
+ loc: BlockManagerId,
+ uncompressedSizes: Array[Long],
+ mapTaskId: Long,
+ checksumVal: Long) = {
+ this(loc, uncompressedSizes.map(MapStatus.compressSize), mapTaskId,
checksumVal)
}
override def location: BlockManagerId = loc
@@ -145,11 +158,14 @@ private[spark] class CompressedMapStatus(
override def mapId: Long = _mapTaskId
+ override def checksumValue: Long = _checksumVal
+
override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException
{
loc.writeExternal(out)
out.writeInt(compressedSizes.length)
out.write(compressedSizes)
out.writeLong(_mapTaskId)
+ out.writeLong(_checksumVal)
}
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
@@ -158,6 +174,7 @@ private[spark] class CompressedMapStatus(
compressedSizes = new Array[Byte](len)
in.readFully(compressedSizes)
_mapTaskId = in.readLong()
+ _checksumVal = in.readLong()
}
}
@@ -172,6 +189,7 @@ private[spark] class CompressedMapStatus(
* @param avgSize average size of the non-empty and non-huge blocks
* @param hugeBlockSizes sizes of huge blocks by their reduceId.
* @param _mapTaskId unique task id for the task
+ * @param _checksumVal checksum value for the task
*/
private[spark] class HighlyCompressedMapStatus private (
private[this] var loc: BlockManagerId,
@@ -179,7 +197,8 @@ private[spark] class HighlyCompressedMapStatus private (
private[this] var emptyBlocks: RoaringBitmap,
private[this] var avgSize: Long,
private[this] var hugeBlockSizes: scala.collection.Map[Int, Byte],
- private[this] var _mapTaskId: Long)
+ private[this] var _mapTaskId: Long,
+ private[this] var _checksumVal: Long = 0)
extends MapStatus with Externalizable {
// loc could be null when the default constructor is called during
deserialization
@@ -187,7 +206,7 @@ private[spark] class HighlyCompressedMapStatus private (
|| numNonEmptyBlocks == 0 || _mapTaskId > 0,
"Average size can only be zero for map stages that produced no output")
- protected def this() = this(null, -1, null, -1, null, -1) // For
deserialization only
+ protected def this() = this(null, -1, null, -1, null, -1, 0) // For
deserialization only
override def location: BlockManagerId = loc
@@ -209,6 +228,8 @@ private[spark] class HighlyCompressedMapStatus private (
override def mapId: Long = _mapTaskId
+ override def checksumValue: Long = _checksumVal
+
override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException
{
loc.writeExternal(out)
emptyBlocks.serialize(out)
@@ -219,6 +240,7 @@ private[spark] class HighlyCompressedMapStatus private (
out.writeByte(kv._2)
}
out.writeLong(_mapTaskId)
+ out.writeLong(_checksumVal)
}
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
@@ -236,6 +258,7 @@ private[spark] class HighlyCompressedMapStatus private (
}
hugeBlockSizes = hugeBlockSizesImpl
_mapTaskId = in.readLong()
+ _checksumVal = in.readLong()
}
}
@@ -243,7 +266,8 @@ private[spark] object HighlyCompressedMapStatus {
def apply(
loc: BlockManagerId,
uncompressedSizes: Array[Long],
- mapTaskId: Long): HighlyCompressedMapStatus = {
+ mapTaskId: Long,
+ checksumVal: Long = 0): HighlyCompressedMapStatus = {
// We must keep track of which blocks are empty so that we don't report a
zero-sized
// block as being non-empty (or vice-versa) when using the average block
size.
var i = 0
@@ -310,6 +334,6 @@ private[spark] object HighlyCompressedMapStatus {
emptyBlocks.trim()
emptyBlocks.runOptimize()
new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize,
- hugeBlockSizes, mapTaskId)
+ hugeBlockSizes, mapTaskId, checksumVal)
}
}
diff --git
a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 3be7d24f7e4e..a7ac20016a0e 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -23,6 +23,7 @@ import org.apache.spark.scheduler.MapStatus
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriter}
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
import org.apache.spark.shuffle.api.ShuffleExecutorComponents
+import org.apache.spark.shuffle.checksum.RowBasedChecksum
import org.apache.spark.util.collection.ExternalSorter
private[spark] class SortShuffleWriter[K, V, C](
@@ -48,17 +49,31 @@ private[spark] class SortShuffleWriter[K, V, C](
private var partitionLengths: Array[Long] = _
+ def getRowBasedChecksums: Array[RowBasedChecksum] = {
+ if (sorter != null) {
+ sorter.getRowBasedChecksums
+ } else {
+ ShuffleDependency.EMPTY_ROW_BASED_CHECKSUMS
+ }
+ }
+
+ def getAggregatedChecksumValue: Long = {
+ if (sorter != null) sorter.getAggregatedChecksumValue else 0
+ }
+
/** Write a bunch of records to this task's output */
override def write(records: Iterator[Product2[K, V]]): Unit = {
sorter = if (dep.mapSideCombine) {
new ExternalSorter[K, V, C](
- context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering,
dep.serializer)
+ context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering,
+ dep.serializer, dep.rowBasedChecksums)
} else {
// In this case we pass neither an aggregator nor an ordering to the
sorter, because we don't
// care whether the keys get sorted in each partition; that will be done
on the reduce side
// if the operation being run is sortByKey.
new ExternalSorter[K, V, V](
- context, aggregator = None, Some(dep.partitioner), ordering = None,
dep.serializer)
+ context, aggregator = None, Some(dep.partitioner), ordering = None,
+ dep.serializer, dep.rowBasedChecksums)
}
sorter.insertAll(records)
@@ -69,7 +84,8 @@ private[spark] class SortShuffleWriter[K, V, C](
dep.shuffleId, mapId, dep.partitioner.numPartitions)
sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter,
writeMetrics)
partitionLengths =
mapOutputWriter.commitAllPartitions(sorter.getChecksums).getPartitionLengths
- mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths,
mapId)
+ mapStatus =
+ MapStatus(blockManager.shuffleServerId, partitionLengths, mapId,
getAggregatedChecksumValue)
}
/** Close this writer, passing along whether the map completed */
diff --git
a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 8dd207b25bb9..4da89a94201a 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -33,7 +33,7 @@ import org.apache.spark.internal.LogKeys.{NUM_BYTES,
TASK_ATTEMPT_ID}
import org.apache.spark.serializer._
import org.apache.spark.shuffle.{ShufflePartitionPairsWriter,
ShuffleWriteMetricsReporter}
import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter,
ShufflePartitionWriter}
-import org.apache.spark.shuffle.checksum.ShuffleChecksumSupport
+import org.apache.spark.shuffle.checksum.{RowBasedChecksum,
ShuffleChecksumSupport}
import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter,
ShuffleBlockId}
import org.apache.spark.util.{CompletionIterator, Utils => TryUtils}
@@ -97,7 +97,8 @@ private[spark] class ExternalSorter[K, V, C](
aggregator: Option[Aggregator[K, V, C]] = None,
partitioner: Option[Partitioner] = None,
ordering: Option[Ordering[K]] = None,
- serializer: Serializer = SparkEnv.get.serializer)
+ serializer: Serializer = SparkEnv.get.serializer,
+ rowBasedChecksums: Array[RowBasedChecksum] = Array.empty)
extends Spillable[WritablePartitionedPairCollection[K,
C]](context.taskMemoryManager())
with Logging with ShuffleChecksumSupport {
@@ -142,10 +143,16 @@ private[spark] class ExternalSorter[K, V, C](
private val forceSpillFiles = new ArrayBuffer[SpilledFile]
@volatile private var readingIterator: SpillableIterator = null
+ /** Checksum calculator for each partition. Empty when shuffle checksum
disabled. */
private val partitionChecksums = createPartitionChecksums(numPartitions,
conf)
def getChecksums: Array[Long] = getChecksumValues(partitionChecksums)
+ def getRowBasedChecksums: Array[RowBasedChecksum] = rowBasedChecksums
+
+ def getAggregatedChecksumValue: Long =
+ RowBasedChecksum.getAggregatedChecksumValue(rowBasedChecksums)
+
// A comparator for keys K that orders them within a partition to allow
aggregation or sorting.
// Can be a partial ordering by hash code if a total ordering is not
provided through by the
// user. (A partial ordering means that equal keys have
comparator.compare(k, k) = 0, but some
@@ -197,16 +204,24 @@ private[spark] class ExternalSorter[K, V, C](
while (records.hasNext) {
addElementsRead()
kv = records.next()
- map.changeValue((actualPartitioner.getPartition(kv._1), kv._1), update)
+ val partitionId = actualPartitioner.getPartition(kv._1)
+ map.changeValue((partitionId, kv._1), update)
maybeSpillCollection(usingMap = true)
+ if (rowBasedChecksums.nonEmpty) {
+ rowBasedChecksums(partitionId).update(kv._1, kv._2)
+ }
}
} else {
// Stick values into our buffer
while (records.hasNext) {
addElementsRead()
val kv = records.next()
- buffer.insert(actualPartitioner.getPartition(kv._1), kv._1,
kv._2.asInstanceOf[C])
+ val partitionId = actualPartitioner.getPartition(kv._1)
+ buffer.insert(partitionId, kv._1, kv._2.asInstanceOf[C])
maybeSpillCollection(usingMap = false)
+ if (rowBasedChecksums.nonEmpty) {
+ rowBasedChecksums(partitionId).update(kv._1, kv._2)
+ }
}
}
}
diff --git
a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index c55254e04f40..b13d8982ad0d 100644
---
a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++
b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -49,6 +49,7 @@ import org.apache.spark.network.util.LimitedInputStream;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.security.CryptoStreamUtils;
import org.apache.spark.serializer.*;
+import org.apache.spark.shuffle.checksum.RowBasedChecksum;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents;
import org.apache.spark.storage.*;
@@ -174,11 +175,18 @@ public class UnsafeShuffleWriterSuite implements
ShuffleChecksumTestHelper {
File file = (File) invocationOnMock.getArguments()[0];
return Utils.tempFileWith(file);
});
-
+ resetDependency(false);
when(taskContext.taskMetrics()).thenReturn(taskMetrics);
+ when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager);
+ }
+
+ private void resetDependency(boolean rowBasedChecksumEnabled) {
when(shuffleDep.serializer()).thenReturn(serializer);
when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
- when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager);
+ final int checksumSize = rowBasedChecksumEnabled ? NUM_PARTITIONS : 0;
+ final RowBasedChecksum[] rowBasedChecksums =
+ createPartitionRowBasedChecksums(checksumSize);
+ when(shuffleDep.rowBasedChecksums()).thenReturn(rowBasedChecksums);
}
private UnsafeShuffleWriter<Object, Object> createWriter(boolean
transferToEnabled)
@@ -613,6 +621,43 @@ public class UnsafeShuffleWriterSuite implements
ShuffleChecksumTestHelper {
assertSpillFilesWereCleanedUp();
}
+ @Test
+ public void testRowBasedChecksum() throws IOException, SparkException {
+ final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
+ for (int i = 0; i < NUM_PARTITIONS; i++) {
+ for (int j = 0; j < 5; j++) {
+ dataToWrite.add(new Tuple2<>(i, i + j));
+ }
+ }
+
+ long[] checksumValues = new long[0];
+ long aggregatedChecksumValue = 0;
+ try {
+ for (int i = 0; i < 100; i++) {
+ resetDependency(true);
+ final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+ Collections.shuffle(dataToWrite);
+ writer.write(dataToWrite.iterator());
+ writer.stop(true);
+
+ if (i == 0) {
+ checksumValues =
getRowBasedChecksumValues(writer.getRowBasedChecksums());
+ assertEquals(checksumValues.length, NUM_PARTITIONS);
+ Arrays.stream(checksumValues).allMatch(v -> v > 0);
+
+ aggregatedChecksumValue = writer.getAggregatedChecksumValue();
+ assert(aggregatedChecksumValue != 0);
+ } else {
+ assertArrayEquals(checksumValues,
+ getRowBasedChecksumValues(writer.getRowBasedChecksums()));
+ assertEquals(aggregatedChecksumValue,
writer.getAggregatedChecksumValue());
+ }
+ }
+ } finally {
+ resetDependency(false);
+ }
+ }
+
@Test
public void testPeakMemoryUsed() throws Exception {
final long recordLengthBytes = 8;
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index 68e366e9ad10..d2344b4e7291 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -272,7 +272,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with
LocalSparkContext {
masterTracker.registerShuffle(20, 100,
MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES)
(0 until 100).foreach { i =>
masterTracker.registerMapOutput(20, i, new CompressedMapStatus(
- BlockManagerId("999", "mps", 1000), createArray(4000000, 0L), 5))
+ BlockManagerId("999", "mps", 1000), createArray(4000000, 0L), 5,
100))
}
val senderAddress = RpcAddress("localhost", 12345)
val rpcCallContext = mock(classOf[RpcCallContext])
@@ -579,7 +579,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with
LocalSparkContext {
masterTracker.registerShuffle(20, 100,
MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES)
(0 until 100).foreach { i =>
masterTracker.registerMapOutput(20, i, new CompressedMapStatus(
- BlockManagerId("999", "mps", 1000), createArray(4000000, 0L), 5))
+ BlockManagerId("999", "mps", 1000), createArray(4000000, 0L), 5,
100))
}
val mapWorkerRpcEnv = createRpcEnv("spark-worker", "localhost", 0, new
SecurityManager(conf))
@@ -626,7 +626,7 @@ class MapOutputTrackerSuite extends SparkFunSuite with
LocalSparkContext {
masterTracker.registerShuffle(20, 100,
MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES)
(0 until 100).foreach { i =>
masterTracker.registerMapOutput(20, i, new CompressedMapStatus(
- BlockManagerId("999", "mps", 1000), createArray(4000000, 0L), 5))
+ BlockManagerId("999", "mps", 1000), createArray(4000000, 0L), 5,
100))
}
masterTracker.registerMergeResult(20, 0,
MergeStatus(BlockManagerId("999", "mps", 1000), 0,
bitmap1, 1000L))
diff --git
a/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala
b/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala
index 75f952d063d3..bd8766bd260e 100644
--- a/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala
+++ b/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala
@@ -59,7 +59,7 @@ object MapStatusesSerDeserBenchmark extends BenchmarkBase {
Array.fill(blockSize) {
// Creating block size ranging from 0byte to 1GB
(r.nextDouble() * 1024 * 1024 * 1024).toLong
- }, i))
+ }, i, i * 100))
}
val shuffleStatus = tracker.shuffleStatuses.get(shuffleId).head
diff --git
a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index bf38c629f700..1ada81cbdd0e 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -1166,7 +1166,8 @@ class DAGSchedulerSuite extends SparkFunSuite with
TempLocalSparkContext with Ti
stageId: Int,
attemptIdx: Int,
numShufflePartitions: Int,
- hostNames: Seq[String] = Seq.empty[String]): Unit = {
+ hostNames: Seq[String] = Seq.empty[String],
+ checksumVal: Long = 0): Unit = {
def compareStageAttempt(taskSet: TaskSet): Boolean = {
taskSet.stageId == stageId && taskSet.stageAttemptId == attemptIdx
}
@@ -1181,7 +1182,7 @@ class DAGSchedulerSuite extends SparkFunSuite with
TempLocalSparkContext with Ti
} else {
s"host${('A' + idx).toChar}"
}
- (Success, makeMapStatus(hostName, numShufflePartitions))
+ (Success, makeMapStatus(hostName, numShufflePartitions, checksumVal =
checksumVal))
}.toSeq)
}
@@ -4852,6 +4853,44 @@ class DAGSchedulerSuite extends SparkFunSuite with
TempLocalSparkContext with Ti
assert(mapStatuses.count(s => s != null && s.location.executorId ==
"hostB-exec") === 1)
}
+ /**
+ * In this test, we simulate a job where some tasks in a stage fail, and it
triggers the retry
+ * of the task in its previous stage. The two attempts of the same task in
the previous stage
+ * produce different shuffle checksums.
+ */
+ test("Tasks that produce different checksum across retries") {
+ setupStageAbortTest(sc)
+
+ val parts = 8
+ val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, new
HashPartitioner(parts))
+ val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker =
mapOutputTracker)
+ submit(reduceRdd, (0 until parts).toArray)
+
+ // Complete stage 0 and then fail stage 1, and tasks in stage 0 produce a
checksum of 100.
+ completeShuffleMapStageSuccessfully(0, 0, numShufflePartitions = parts,
checksumVal = 100)
+ completeNextStageWithFetchFailure(1, 0, shuffleDep)
+
+ // Resubmit and confirm that now all is well.
+ scheduler.resubmitFailedStages()
+ assert(scheduler.runningStages.nonEmpty)
+ assert(!ended)
+
+ // Complete stage 0 and then stage 1, and the retried task in stage 0
produces a different
+ // checksum of 200.
+ completeShuffleMapStageSuccessfully(0, 1, numShufflePartitions = parts,
checksumVal = 200)
+ completeNextResultStageWithSuccess(1, 1)
+
+ // Confirm job finished successfully.
+ sc.listenerBus.waitUntilEmpty()
+ assert(ended)
+ assert(results == (0 until parts).map { idx => idx -> 42 }.toMap)
+ assert(
+
mapOutputTracker.shuffleStatuses(shuffleDep.shuffleId).checksumMismatchIndices.size
== 1)
+ assertDataStructuresEmpty()
+ mapOutputTracker.unregisterShuffle(shuffleDep.shuffleId)
+ }
+
Seq(true, false).foreach { registerMergeResults =>
test("SPARK-40096: Send finalize events even if shuffle merger blocks
indefinitely " +
s"with registerMergeResults is ${registerMergeResults}") {
@@ -5245,8 +5284,13 @@ class DAGSchedulerAbortStageOffSuite extends
DAGSchedulerSuite {
object DAGSchedulerSuite {
val mergerLocs = ArrayBuffer[BlockManagerId]()
- def makeMapStatus(host: String, reduces: Int, sizes: Byte = 2, mapTaskId:
Long = -1): MapStatus =
- MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes),
mapTaskId)
+ def makeMapStatus(
+ host: String,
+ reduces: Int,
+ sizes: Byte = 2,
+ mapTaskId: Long = -1,
+ checksumVal: Long = 0): MapStatus =
+ MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes),
mapTaskId, checksumVal)
def makeBlockManagerId(host: String, execId: Option[String] = None):
BlockManagerId = {
BlockManagerId(execId.getOrElse(host + "-exec"), host, 12345)
diff --git
a/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala
b/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala
index 8be103b7be86..439d75ee9364 100644
---
a/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala
+++
b/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala
@@ -19,6 +19,8 @@ package org.apache.spark.shuffle
import java.io.File
+import org.apache.spark.shuffle.checksum.{OutputStreamRowBasedChecksum,
RowBasedChecksum}
+
trait ShuffleChecksumTestHelper {
/**
@@ -37,4 +39,16 @@ trait ShuffleChecksumTestHelper {
assert(ShuffleChecksumUtils.compareChecksums(numPartition, algorithm,
checksum, data, index),
"checksum must be consistent at both write and read sides")
}
+
+ def getRowBasedChecksumValues(rowBasedChecksums: Array[RowBasedChecksum]):
Array[Long] = {
+ if (rowBasedChecksums.isEmpty) {
+ Array.empty
+ } else {
+ rowBasedChecksums.map(_.getValue)
+ }
+ }
+
+ def createPartitionRowBasedChecksums(numPartitions: Int):
Array[RowBasedChecksum] = {
+ Array.tabulate(numPartitions)(_ => new
OutputStreamRowBasedChecksum("ADLER32"))
+ }
}
diff --git
a/core/src/test/scala/org/apache/spark/shuffle/checksum/OutputStreamRowBasedChecksum.scala
b/core/src/test/scala/org/apache/spark/shuffle/checksum/OutputStreamRowBasedChecksum.scala
new file mode 100644
index 000000000000..3abec5f4bd65
--- /dev/null
+++
b/core/src/test/scala/org/apache/spark/shuffle/checksum/OutputStreamRowBasedChecksum.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.checksum
+
+import java.io.ObjectOutputStream
+import java.util.zip.Checksum
+
+import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper
+import org.apache.spark.util.ExposedBufferByteArrayOutputStream
+
+/**
+ * A Concrete implementation of RowBasedChecksum. The checksum for each row is
+ * computed by first converting the (key, value) pair to byte array using OutputStreams,
+ * and then computing the checksum for the byte array.
+ * Note that this checksum computation is very expensive, and it is used only in tests
+ * in the core component. A much cheaper implementation of RowBasedChecksum is in
+ * UnsafeRowChecksum.
+ *
+ * @param checksumAlgorithm the algorithm used for computing checksum.
+ */
+class OutputStreamRowBasedChecksum(checksumAlgorithm: String)
+ extends RowBasedChecksum() {
+
+ private val DEFAULT_INITIAL_SER_BUFFER_SIZE = 32 * 1024
+
+ @transient private lazy val serBuffer =
+ new ExposedBufferByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE)
+ @transient private lazy val objOut = new ObjectOutputStream(serBuffer)
+
+ @transient
+ protected lazy val checksum: Checksum =
+ ShuffleChecksumHelper.getChecksumByAlgorithm(checksumAlgorithm)
+
+ override protected def calculateRowChecksum(key: Any, value: Any): Long = {
+ assert(checksum != null, "Checksum is null")
+
+ // Converts the (key, value) pair into byte array.
+ objOut.reset()
+ serBuffer.reset()
+ objOut.writeObject((key, value))
+ objOut.flush()
+ serBuffer.flush()
+
+ // Computes and returns the checksum for the byte array.
+ checksum.reset()
+ checksum.update(serBuffer.getBuf, 0, serBuffer.size())
+ checksum.getValue
+ }
+}
diff --git
a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
index ce2aefa74229..c908c06b399d 100644
---
a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
+++
b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
@@ -22,6 +22,7 @@ import java.util.UUID
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
+import scala.util.Random
import org.mockito.{Mock, MockitoAnnotations}
import org.mockito.Answers.RETURNS_SMART_NULLS
@@ -74,8 +75,7 @@ class BypassMergeSortShuffleWriterSuite
)
val memoryManager = new TestMemoryManager(conf)
val taskMemoryManager = new TaskMemoryManager(memoryManager, 0)
- when(dependency.partitioner).thenReturn(new HashPartitioner(7))
- when(dependency.serializer).thenReturn(new JavaSerializer(conf))
+ resetDependency(conf, rowBasedChecksumEnabled = false)
when(taskContext.taskMetrics()).thenReturn(taskMetrics)
when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile)
when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
@@ -145,6 +145,20 @@ class BypassMergeSortShuffleWriterSuite
}
}
+ private def resetDependency(sc: SparkConf, rowBasedChecksumEnabled:
Boolean): Unit = {
+ reset(dependency)
+ val numPartitions = 7
+ when(dependency.partitioner).thenReturn(new HashPartitioner(numPartitions))
+ when(dependency.serializer).thenReturn(new JavaSerializer(sc))
+ val checksumSize = if (rowBasedChecksumEnabled) {
+ numPartitions
+ } else {
+ 0
+ }
+ val rowBasedChecksums = createPartitionRowBasedChecksums(checksumSize)
+ when(dependency.rowBasedChecksums).thenReturn(rowBasedChecksums)
+ }
+
test("write empty iterator") {
val writer = new BypassMergeSortShuffleWriter[Int, Int](
blockManager,
@@ -294,4 +308,44 @@ class BypassMergeSortShuffleWriterSuite
assert(checksumFile.length() === 8 * numPartition)
compareChecksums(numPartition, checksumAlgorithm, checksumFile, dataFile,
indexFile)
}
+
+ test("Row-based checksums are independent of input row order") {
+ val records: List[(Int, Int)] = List(
+ (1, 1), (1, 2), (1, 3), (1, 4), (1, 5),
+ (2, 2), (2, 3), (2, 4), (2, 5), (2, 6),
+ (3, 3), (3, 4), (3, 5), (3, 6), (3, 7),
+ (4, 4), (4, 5), (4, 6), (4, 7), (4, 8),
+ (5, 5), (5, 6), (5, 7), (5, 8), (5, 9),
+ (6, 6), (6, 7), (6, 8), (6, 9), (6, 10),
+ (7, 7), (7, 8), (7, 9), (7, 10), (7, 11))
+
+ var checksumValues : Array[Long] = Array[Long]()
+ var aggregatedChecksumValue = 0L
+ for (i <- 1 to 100) {
+ resetDependency(conf, rowBasedChecksumEnabled = true)
+ val writer = new BypassMergeSortShuffleWriter[Int, Int](
+ blockManager,
+ shuffleHandle,
+ 0L, // MapId
+ conf,
+ taskContext.taskMetrics().shuffleWriteMetrics,
+ shuffleExecutorComponents)
+
+ writer.write(Random.shuffle(records).iterator)
+ writer.stop(/* success = */ true)
+
+ if(i == 1) {
+ checksumValues = getRowBasedChecksumValues(writer.getRowBasedChecksums)
+ assert(checksumValues.length > 0)
+ assert(checksumValues.forall(_ > 0))
+
+ aggregatedChecksumValue = writer.getAggregatedChecksumValue()
+ assert(aggregatedChecksumValue != 0)
+ } else {
+ assert(checksumValues.sameElements(
+ getRowBasedChecksumValues(writer.getRowBasedChecksums)))
+ assert(aggregatedChecksumValue == writer.getAggregatedChecksumValue())
+ }
+ }
+ }
}
diff --git
a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
index 99402abb16ca..9d4b0625f762 100644
---
a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
+++
b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.shuffle.sort
+import scala.util.Random
+
import org.mockito.{Mock, MockitoAnnotations}
import org.mockito.Answers.RETURNS_SMART_NULLS
import org.mockito.Mockito._
@@ -50,6 +52,7 @@ class SortShuffleWriterSuite
private val shuffleBlockResolver = new IndexShuffleBlockResolver(conf)
private val serializer = new JavaSerializer(conf)
private var shuffleExecutorComponents: ShuffleExecutorComponents = _
+ @Mock(answer = RETURNS_SMART_NULLS) private var dependency:
ShuffleDependency[Int, Int, Int] = _
private val partitioner = new Partitioner() {
def numPartitions = numMaps
@@ -60,13 +63,9 @@ class SortShuffleWriterSuite
super.beforeEach()
MockitoAnnotations.openMocks(this).close()
shuffleHandle = {
- val dependency = mock(classOf[ShuffleDependency[Int, Int, Int]])
- when(dependency.partitioner).thenReturn(partitioner)
- when(dependency.serializer).thenReturn(serializer)
- when(dependency.aggregator).thenReturn(None)
- when(dependency.keyOrdering).thenReturn(None)
new BaseShuffleHandle(shuffleId, dependency)
}
+ resetDependency(rowBasedChecksumEnabled = false)
shuffleExecutorComponents = new LocalDiskShuffleExecutorComponents(
conf, blockManager, shuffleBlockResolver)
}
@@ -79,6 +78,21 @@ class SortShuffleWriterSuite
}
}
+ private def resetDependency(rowBasedChecksumEnabled: Boolean): Unit = {
+ reset(dependency)
+ when(dependency.partitioner).thenReturn(partitioner)
+ when(dependency.serializer).thenReturn(serializer)
+ when(dependency.aggregator).thenReturn(None)
+ when(dependency.keyOrdering).thenReturn(None)
+ val checksumSize = if (rowBasedChecksumEnabled) {
+ numMaps
+ } else {
+ 0
+ }
+ val rowBasedChecksums = createPartitionRowBasedChecksums(checksumSize)
+ when(dependency.rowBasedChecksums).thenReturn(rowBasedChecksums)
+ }
+
test("write empty iterator") {
val context = MemoryTestingUtils.fakeTaskContext(sc.env)
val writer = new SortShuffleWriter[Int, Int, Int](
@@ -114,6 +128,44 @@ class SortShuffleWriterSuite
assert(records.size === writeMetrics.recordsWritten)
}
+ test("Row-based checksums are independent of input row order") {
+ val shuffleBlockResolver = new IndexShuffleBlockResolver(conf)
+ val context = MemoryTestingUtils.fakeTaskContext(sc.env)
+ val records: List[(Int, Int)] = List(
+ (1, 1), (1, 2), (1, 3), (1, 4), (1, 5),
+ (2, 2), (2, 3), (2, 4), (2, 5), (2, 6),
+ (3, 3), (3, 4), (3, 5), (3, 6), (3, 7),
+ (4, 4), (4, 5), (4, 6), (4, 7), (4, 8),
+ (5, 5), (5, 6), (5, 7), (5, 8), (5, 9))
+
+ var checksumValues : Array[Long] = Array[Long]()
+ var aggregatedChecksumValue = 0L
+ for (i <- 1 to 100) {
+ resetDependency(rowBasedChecksumEnabled = true)
+ val writer = new SortShuffleWriter[Int, Int, Int](
+ shuffleHandle,
+ mapId = 2,
+ context,
+ context.taskMetrics().shuffleWriteMetrics,
+ new LocalDiskShuffleExecutorComponents(
+ conf, shuffleBlockResolver._blockManager, shuffleBlockResolver))
+ writer.write(Random.shuffle(records).iterator)
+ if(i == 1) {
+ checksumValues = getRowBasedChecksumValues(writer.getRowBasedChecksums)
+ assert(checksumValues.length > 0)
+ assert(checksumValues.forall(_ > 0))
+
+ aggregatedChecksumValue = writer.getAggregatedChecksumValue
+ assert(aggregatedChecksumValue != 0)
+ } else {
+ assert(checksumValues.sameElements(
+ getRowBasedChecksumValues(writer.getRowBasedChecksums)))
+ assert(aggregatedChecksumValue == writer.getAggregatedChecksumValue)
+ }
+ writer.stop(success = true)
+ }
+ }
+
Seq((true, false, false),
(true, true, false),
(true, false, true),
@@ -141,6 +193,7 @@ class SortShuffleWriterSuite
when(dependency.serializer).thenReturn(serializer)
when(dependency.aggregator).thenReturn(aggregator)
when(dependency.keyOrdering).thenReturn(order)
+ when(dependency.rowBasedChecksums).thenReturn(Array.empty)
new BaseShuffleHandle[Int, Int, Int](shuffleId, dependency)
}
diff --git
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowChecksum.scala
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowChecksum.scala
new file mode 100644
index 000000000000..2be675070eb1
--- /dev/null
+++
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowChecksum.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.shuffle.checksum.RowBasedChecksum
+
+/**
+ * A concrete implementation of RowBasedChecksum that computes the checksum for UnsafeRow values.
+ * The checksum for each row is computed by hashing the row's underlying bytes (getBaseObject,
+ * starting at getBaseOffset and spanning getSizeInBytes) with XXH64.
+ *
+ * Note that the input key is ignored in the checksum computation. As Spark shuffle currently
+ * uses a PartitionIdPassthrough partitioner here, the keys are already the partition IDs used
+ * to route the data, and they are the same for all rows in the same partition.
+ */
+class UnsafeRowChecksum extends RowBasedChecksum() {
+
+ override protected def calculateRowChecksum(key: Any, value: Any): Long = {
+ assert(
+ value.isInstanceOf[UnsafeRow],
+ "Expecting UnsafeRow but got " + value.getClass.getName)
+
+ // Hashes the UnsafeRow's underlying bytes (getBaseObject/getBaseOffset/getSizeInBytes) with XXH64.
+ val unsafeRow = value.asInstanceOf[UnsafeRow]
+ XXH64.hashUnsafeBytes(
+ unsafeRow.getBaseObject,
+ unsafeRow.getBaseOffset,
+ unsafeRow.getSizeInBytes,
+ 0
+ )
+ }
+}
+
+object UnsafeRowChecksum {
+ def createUnsafeRowChecksums(numPartitions: Int): Array[RowBasedChecksum] = {
+ Array.tabulate(numPartitions)(_ => new UnsafeRowChecksum())
+ }
+}
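
The order-independent aggregation over these per-row XXH64 hashes is implemented in RowBasedChecksum in core and is not part of this file. As a rough, self-contained illustration only (not Spark's implementation), per-row hashes can be folded into commutative accumulators so that the final value does not depend on the order in which update() is called; the hash function and mixing below are placeholders:

    object OrderIndependentChecksumSketch {
      private var sum: Long = 0L
      private var xorAcc: Long = 0L

      // Per-row hash; Spark's UnsafeRowChecksum uses XXH64 over the row's backing bytes,
      // a simple FNV-1a variant keeps this sketch dependency-free.
      private def hashRow(bytes: Array[Byte]): Long = {
        var h = 0xcbf29ce484222325L
        bytes.foreach { b => h = (h ^ (b & 0xffL)) * 0x100000001b3L }
        h
      }

      def update(rowBytes: Array[Byte]): Unit = {
        val h = hashRow(rowBytes)
        sum += h     // commutative
        xorAcc ^= h  // commutative
      }

      def reset(): Unit = { sum = 0L; xorAcc = 0L }

      // Mixing the two accumulators ensures that adding the same row twice still changes the value.
      def value: Long = sum * 31L + xorAcc

      def main(args: Array[String]): Unit = {
        val rows = Seq("t1", "t2", "t3").map(_.getBytes("UTF-8"))
        rows.foreach(update)
        val forward = value
        reset()
        rows.reverse.foreach(update)
        assert(value == forward) // same rows, different order => same checksum
        println(s"order-independent checksum = $forward")
      }
    }
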
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 5a99814c8cc8..87a4664aaa75 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -875,6 +875,21 @@ object SQLConf {
.checkValue(_ > 0, "The value of spark.sql.shuffle.partitions must be positive")
.createWithDefault(200)
+ val SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED =
+   buildConf("spark.sql.shuffle.orderIndependentChecksum.enabled")
+     .doc("Whether to calculate an order-independent checksum for the shuffle data. If " +
+       "enabled, Spark will calculate a checksum that is independent of the input row order " +
+       "for each mapper and return the checksums from the executors to the driver. This is " +
+       "different from the checksum computed when spark.shuffle.checksum.enabled is enabled, " +
+       "which is sensitive to the shuffle data ordering and is used to detect file " +
+       "corruption. This checksum stays the same even if the shuffle row order changes, and " +
+       "it is used to detect whether different task attempts of the same partition produce " +
+       "the same set of (key, value) pairs. If the output data changes across retries, Spark " +
+       "will need to retry all tasks of the consumer stages to avoid correctness issues.")
+     .version("4.1.0")
+     .booleanConf
+     .createWithDefault(false)
+
val SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE =
buildConf("spark.sql.adaptive.shuffle.targetPostShuffleInputSize")
.internal()
@@ -6621,6 +6636,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
}
}
+ def shuffleOrderIndependentChecksumEnabled: Boolean =
+ getConf(SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED)
+
def allowCollationsInMapKeys: Boolean = getConf(ALLOW_COLLATIONS_IN_MAP_KEYS)
def objectLevelCollationsEnabled: Boolean =
getConf(OBJECT_LEVEL_COLLATIONS_ENABLED)
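
For reference, a minimal usage sketch of the new flag; only the config key comes from the diff above, while the app name, master, and sample query are illustrative:

    import org.apache.spark.sql.SparkSession

    object OrderIndependentChecksumDemo {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[*]")
          .appName("order-independent-checksum-demo")
          // Key added by this commit; the default is false.
          .config("spark.sql.shuffle.orderIndependentChecksum.enabled", "true")
          .getOrCreate()

        // Any shuffle-producing query now also computes row-based checksums in the shuffle writers.
        spark.range(1000).repartition(10).count()
        spark.stop()
      }
    }
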
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index 3b8fa821eac7..9c86bbb606a5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -30,7 +30,7 @@ import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{ShuffleWriteMetricsReporter, ShuffleWriteProcessor}
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow, UnsafeRowChecksum}
import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
import org.apache.spark.sql.catalyst.plans.logical.Statistics
@@ -480,12 +480,19 @@ object ShuffleExchangeExec {
// Now, we manually create a ShuffleDependency. Because pairs in rddWithPartitionIds
// are in the form of (partitionId, row) and every partitionId is in the expected range
// [0, part.numPartitions - 1]. The partitioner of this is a PartitionIdPassthrough.
+ val checksumSize =
+ if (SQLConf.get.shuffleOrderIndependentChecksumEnabled) {
+ part.numPartitions
+ } else {
+ 0
+ }
val dependency =
new ShuffleDependency[Int, InternalRow, InternalRow](
rddWithPartitionIds,
new PartitionIdPassthrough(part.numPartitions),
serializer,
- shuffleWriterProcessor = createShuffleWriteProcessor(writeMetrics))
+ shuffleWriterProcessor = createShuffleWriteProcessor(writeMetrics),
+ rowBasedChecksums = UnsafeRowChecksum.createUnsafeRowChecksums(checksumSize))
dependency
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala
new file mode 100644
index 000000000000..0fe660312210
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.{MapOutputTrackerMaster, SparkFunSuite}
+import org.apache.spark.sql.classic.SparkSession
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SQLTestUtils
+
+class MapStatusEndToEndSuite extends SparkFunSuite with SQLTestUtils {
+ override def spark: SparkSession = SparkSession.builder()
+ .master("local")
+ .config(SQLConf.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED.key, value = true)
+ .config(SQLConf.LEAF_NODE_DEFAULT_PARALLELISM.key, value = 5)
+ .config(SQLConf.CLASSIC_SHUFFLE_DEPENDENCY_FILE_CLEANUP_ENABLED.key, value = false)
+ .getOrCreate()
+
+ override def afterAll(): Unit = {
+ // This suite should not interfere with the other test suites.
+ SparkSession.getActiveSession.foreach(_.stop())
+ SparkSession.clearActiveSession()
+ SparkSession.getDefaultSession.foreach(_.stop())
+ SparkSession.clearDefaultSession()
+ }
+
+ test("Propagate checksum from executor to driver") {
+ assert(spark.sparkContext.conf
+ .get("spark.sql.shuffle.orderIndependentChecksum.enabled") == "true")
+ assert(spark.conf.get("spark.sql.shuffle.orderIndependentChecksum.enabled") == "true")
+ assert(spark.sparkContext.conf.get("spark.sql.leafNodeDefaultParallelism") == "5")
+ assert(spark.conf.get("spark.sql.leafNodeDefaultParallelism") == "5")
+ assert(spark.sparkContext.conf.get("spark.sql.classic.shuffleDependency.fileCleanup.enabled")
+ == "false")
+ assert(spark.conf.get("spark.sql.classic.shuffleDependency.fileCleanup.enabled") == "false")
+
+ withTable("t") {
+ spark.range(1000).repartition(10).write.mode("overwrite").
+ saveAsTable("t")
+ }
+
+ val shuffleStatuses = spark.sparkContext.env.mapOutputTracker.
+ asInstanceOf[MapOutputTrackerMaster].shuffleStatuses
+ assert(shuffleStatuses.size == 1)
+
+ val mapStatuses = shuffleStatuses(0).mapStatuses
+ assert(mapStatuses.length == 5)
+ assert(mapStatuses.forall(_.checksumValue != 0))
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowChecksumSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowChecksumSuite.scala
new file mode 100644
index 000000000000..07941ad62633
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowChecksumSuite.scala
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.nio.ByteBuffer
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeRowChecksum}
+import org.apache.spark.sql.types._
+
+class UnsafeRowChecksumSuite extends SparkFunSuite {
+ private val schema = new StructType().add("value", IntegerType)
+ private val toUnsafeRow = ExpressionEncoder(schema).createSerializer()
+
+ private val schemaComplex = new StructType()
+ .add("stringCol", StringType)
+ .add("doubleCol", DoubleType)
+ .add("longCol", LongType)
+ .add("int32Col", IntegerType)
+ .add("int16Col", ShortType)
+ .add("int8Col", ByteType)
+ .add("boolCol", BooleanType)
+ private val toUnsafeRowComplex = ExpressionEncoder(schemaComplex).createSerializer()
+
+ private def setUnsafeRowValue(
+ stringCol: String,
+ doubleCol: Double,
+ longCol: Long,
+ int32Col: Int,
+ int16Col: Short,
+ int8Col: Byte,
+ boolCol: Boolean,
+ unsafeRowOffheap: UnsafeRow): Unit = {
+ unsafeRowOffheap.writeFieldTo(0, ByteBuffer.wrap(stringCol.getBytes))
+ unsafeRowOffheap.setDouble(1, doubleCol)
+ unsafeRowOffheap.setLong(2, longCol)
+ unsafeRowOffheap.setInt(3, int32Col)
+ unsafeRowOffheap.setShort(4, int16Col)
+ unsafeRowOffheap.setByte(5, int8Col)
+ unsafeRowOffheap.setBoolean(6, boolCol)
+ }
+
+ test("Non-UnsafeRow value should fail") {
+ val rowBasedChecksum = new UnsafeRowChecksum()
+ rowBasedChecksum.update(1, Long.box(20))
+ // We fail to compute the checksum, and getValue returns 0.
+ assert(rowBasedChecksum.getValue == 0)
+ }
+
+ test("Two identical rows should not have a checksum of zero") {
+ val rowBasedChecksum = new UnsafeRowChecksum()
+ assert(rowBasedChecksum.getValue == 0)
+
+ // Updates the checksum with one row.
+ rowBasedChecksum.update(1, toUnsafeRow(Row(20)))
+ assert(rowBasedChecksum.getValue == -9094624449814316735L)
+
+ // Updates the checksum with the same row again. Because the final value mixes the XOR
+ // and the sum of the per-row checksums, the result is not 0.
+ rowBasedChecksum.update(1, toUnsafeRow(Row(20)))
+ assert(rowBasedChecksum.getValue == -1240577858172431653L)
+ }
+
+ test("The checksum is independent of row order - two rows") {
+ val rowBasedChecksum1 = new UnsafeRowChecksum()
+ val rowBasedChecksum2 = new UnsafeRowChecksum()
+ assert(rowBasedChecksum1.getValue == 0)
+ assert(rowBasedChecksum2.getValue == 0)
+
+ rowBasedChecksum1.update(1, toUnsafeRow(Row(20)))
+ rowBasedChecksum2.update(1, toUnsafeRow(Row(40)))
+ assert(rowBasedChecksum1.getValue != rowBasedChecksum2.getValue)
+
+ rowBasedChecksum1.update(2, toUnsafeRow(Row(40)))
+ rowBasedChecksum2.update(2, toUnsafeRow(Row(20)))
+ assert(rowBasedChecksum1.getValue == rowBasedChecksum2.getValue)
+
+ assert(rowBasedChecksum1.getValue != 0)
+ assert(rowBasedChecksum2.getValue != 0)
+ }
+
+ test("The checksum is independent of row order - multiple rows") {
+ val rowBasedChecksum1 = new UnsafeRowChecksum()
+ val rowBasedChecksum2 = new UnsafeRowChecksum()
+ assert(rowBasedChecksum1.getValue == 0)
+ assert(rowBasedChecksum2.getValue == 0)
+
+ rowBasedChecksum1.update(1, toUnsafeRow(Row(20)))
+ rowBasedChecksum2.update(1, toUnsafeRow(Row(100)))
+ assert(rowBasedChecksum1.getValue != rowBasedChecksum2.getValue)
+
+ rowBasedChecksum1.update(2, toUnsafeRow(Row(40)))
+ rowBasedChecksum2.update(2, toUnsafeRow(Row(80)))
+ assert(rowBasedChecksum1.getValue != rowBasedChecksum2.getValue)
+
+ rowBasedChecksum1.update(3, toUnsafeRow(Row(60)))
+ rowBasedChecksum2.update(3, toUnsafeRow(Row(60)))
+ assert(rowBasedChecksum1.getValue != rowBasedChecksum2.getValue)
+
+ rowBasedChecksum1.update(4, toUnsafeRow(Row(80)))
+ rowBasedChecksum2.update(4, toUnsafeRow(Row(40)))
+ assert(rowBasedChecksum1.getValue != rowBasedChecksum2.getValue)
+
+ rowBasedChecksum1.update(5, toUnsafeRow(Row(100)))
+ rowBasedChecksum2.update(5, toUnsafeRow(Row(20)))
+ assert(rowBasedChecksum1.getValue == rowBasedChecksum2.getValue)
+
+ assert(rowBasedChecksum1.getValue != 0)
+ assert(rowBasedChecksum2.getValue != 0)
+ }
+
+ test("The checksum is independent of row order - complex rows") {
+ val rowBasedChecksum1 = new UnsafeRowChecksum()
+ val rowBasedChecksum2 = new UnsafeRowChecksum()
+ assert(rowBasedChecksum1.getValue == 0)
+ assert(rowBasedChecksum2.getValue == 0)
+
+ rowBasedChecksum1.update(1, toUnsafeRowComplex(Row(
+ "Some string", 0.99, 10000L, 1000, 100.toShort, 10.toByte, true)))
+ rowBasedChecksum2.update(1, toUnsafeRowComplex(Row(
+ "Some other string", 10.88, 20000L, 2000, 200.toShort, 20.toByte,
false)))
+ assert(rowBasedChecksum1.getValue != rowBasedChecksum2.getValue)
+
+ rowBasedChecksum1.update(2, toUnsafeRowComplex(Row(
+ "Some other string", 10.88, 20000L, 2000, 200.toShort, 20.toByte,
false)))
+ rowBasedChecksum2.update(2, toUnsafeRowComplex(Row(
+ "Some string", 0.99, 10000L, 1000, 100.toShort, 10.toByte, true)))
+ assert(rowBasedChecksum1.getValue == rowBasedChecksum2.getValue)
+
+ assert(rowBasedChecksum1.getValue != 0)
+ assert(rowBasedChecksum2.getValue != 0)
+ }
+}
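
As the new SQL conf documents, these checksums exist to detect whether two attempts of the same map task produced the same set of (key, value) pairs per reduce partition. The driver-side comparison is not part of this commit; the following is only a hypothetical sketch of what such a check could look like, with the arrays standing in for the per-partition checksum values reported by two attempts:

    object ChecksumMismatchSketch {
      // Two attempts of the same map task produced different shuffle output iff the
      // order-independent checksum of any reduce partition differs between them.
      def outputsDiffer(attempt1: Array[Long], attempt2: Array[Long]): Boolean = {
        require(attempt1.length == attempt2.length, "attempts must cover the same partitions")
        attempt1.zip(attempt2).exists { case (a, b) => a != b }
      }

      def main(args: Array[String]): Unit = {
        val first  = Array(0x1234L, 0x5678L)
        val second = Array(0x1234L, 0x9abcL)
        println(s"outputs differ: ${outputsDiffer(first, second)}") // true
      }
    }
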
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]