This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch branch-4.1
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.1 by this push:
new b119334b8439 [SPARK-55619][SQL][4.1] Fix custom metrics in case of
coalesced partitions
b119334b8439 is described below
commit b119334b8439a21eff32819beb81749c0488f52c
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Sat Feb 21 16:27:03 2026 -0800
[SPARK-55619][SQL][4.1] Fix custom metrics in case of coalesced partitions
### What changes were proposed in this pull request?
Replace `PartitionMetricCallback` with a `ConcurrentHashMap` keyed by task
attempt ID to correctly track reader state across multiple `compute()` calls
when `DataSourceRDD` is coalesced. The completion listener is registered only
once per task attempt, and metrics are flushed and carried forward between
readers as partitions are advanced.
### Why are the changes needed?
When `DataSourceRDD` is coalesced (e.g., via `.coalesce(1)`), `compute()`
gets called multiple times per task, which causes the custom metrics to be incorrect.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Unit test
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Claude Sonnet 4.6
Closes #54407 from viirya/SPARK-55619-branch-4.1.
Authored-by: Liang-Chi Hsieh <[email protected]>
Signed-off-by: Liang-Chi Hsieh <[email protected]>
---
.../execution/datasources/v2/DataSourceRDD.scala | 95 ++++++++++++----------
.../connector/KeyGroupedPartitioningSuite.scala | 15 ++++
2 files changed, 69 insertions(+), 41 deletions(-)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
index fbf5c06fe051..19a057c72506 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.datasources.v2
+import java.util.concurrent.ConcurrentHashMap
+
import scala.language.existentials
import org.apache.spark._
@@ -24,7 +26,6 @@ import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.connector.metric.CustomTaskMetric
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader,
PartitionReaderFactory}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric}
@@ -34,6 +35,19 @@ import org.apache.spark.util.ArrayImplicits._
class DataSourceRDDPartition(val index: Int, val inputPartitions:
Seq[InputPartition])
extends Partition with Serializable
+/**
+ * Holds the state for a reader in a task, used by the completion listener to
access the most
+ * recently created reader and iterator for final metrics updates and cleanup.
+ *
+ * When `compute()` is called multiple times for the same task (e.g., when
DataSourceRDD is
+ * coalesced), this state is updated on each call to track the most recent
reader. The task
+ * completion listener then uses this most recent reader for final cleanup and
metrics reporting.
+ *
+ * @param reader The partition reader
+ * @param iterator The metrics iterator wrapping the reader
+ */
+private case class ReaderState(reader: PartitionReader[_], iterator:
MetricsIterator[_])
+
// TODO: we should have 2 RDDs: an RDD[InternalRow] for row-based scan, an
`RDD[ColumnarBatch]` for
// columnar scan.
class DataSourceRDD(
@@ -44,6 +58,11 @@ class DataSourceRDD(
customMetrics: Map[String, SQLMetric])
extends RDD[InternalRow](sc, Nil) {
+ // Map from task attempt ID to the most recently created ReaderState for
that task.
+ // When compute() is called multiple times for the same task (due to
coalescing), the map entry
+ // is updated each time so the completion listener always closes the last
reader.
+ @transient private lazy val taskReaderStates = new ConcurrentHashMap[Long,
ReaderState]()
+
override protected def getPartitions: Array[Partition] = {
inputPartitions.zipWithIndex.map {
case (inputPartitions, index) => new DataSourceRDDPartition(index,
inputPartitions)
@@ -56,20 +75,34 @@ class DataSourceRDD(
}
override def compute(split: Partition, context: TaskContext):
Iterator[InternalRow] = {
+ val taskAttemptId = context.taskAttemptId()
+
+ // Add completion listener only once per task attempt. When compute() is
called a second time
+ // for the same task (e.g., due to coalescing), the first call will have
already put a
+ // ReaderState into taskReaderStates, so containsKey returns true and we
skip this block.
+ if (!taskReaderStates.containsKey(taskAttemptId)) {
+ context.addTaskCompletionListener[Unit] { ctx =>
+ // In case of early stopping before consuming the entire iterator,
+ // we need to do one more metric update at the end of the task.
+ try {
+ val readerState = taskReaderStates.get(ctx.taskAttemptId())
+ if (readerState != null) {
+ CustomMetrics.updateMetrics(
+ readerState.reader.currentMetricsValues.toImmutableArraySeq,
customMetrics)
+ readerState.iterator.forceUpdateMetrics()
+ readerState.reader.close()
+ }
+ } finally {
+ taskReaderStates.remove(ctx.taskAttemptId())
+ }
+ }
+ }
val iterator = new Iterator[Object] {
private val inputPartitions = castPartition(split).inputPartitions
private var currentIter: Option[Iterator[Object]] = None
private var currentIndex: Int = 0
- private val partitionMetricCallback = new
PartitionMetricCallback(customMetrics)
-
- // In case of early stopping before consuming the entire iterator,
- // we need to do one more metric update at the end of the task.
- context.addTaskCompletionListener[Unit] { _ =>
- partitionMetricCallback.execute()
- }
-
override def hasNext: Boolean = currentIter.exists(_.hasNext) ||
advanceToNextIter()
override def next(): Object = {
@@ -97,9 +130,18 @@ class DataSourceRDD(
(iter, rowReader)
}
- // Once we advance to the next partition, update the metric callback
for early finish
- val previousMetrics = partitionMetricCallback.advancePartition(iter,
reader)
- previousMetrics.foreach(reader.initMetricsValues)
+ // Flush metrics and close the previous reader before advancing to
the next one.
+ // Pass the accumulated metrics to the new reader so they carry
forward correctly.
+ val prevState = taskReaderStates.get(taskAttemptId)
+ if (prevState != null) {
+ val metrics = prevState.reader.currentMetricsValues
+ CustomMetrics.updateMetrics(metrics.toImmutableArraySeq,
customMetrics)
+ reader.initMetricsValues(metrics)
+ prevState.reader.close()
+ }
+
+ // Update the map so the completion listener always references the
latest reader.
+ taskReaderStates.put(taskAttemptId, ReaderState(reader, iter))
currentIter = Some(iter)
hasNext
@@ -115,35 +157,6 @@ class DataSourceRDD(
}
}
-private class PartitionMetricCallback
- (customMetrics: Map[String, SQLMetric]) {
- private var iter: MetricsIterator[_] = null
- private var reader: PartitionReader[_] = null
-
- def advancePartition(
- iter: MetricsIterator[_],
- reader: PartitionReader[_]): Option[Array[CustomTaskMetric]] = {
- val metrics = execute()
-
- this.iter = iter
- this.reader = reader
-
- metrics
- }
-
- def execute(): Option[Array[CustomTaskMetric]] = {
- if (iter != null && reader != null) {
- val metrics = reader.currentMetricsValues
- CustomMetrics.updateMetrics(metrics.toImmutableArraySeq, customMetrics)
- iter.forceUpdateMetrics()
- reader.close()
- Some(metrics)
- } else {
- None
- }
- }
-}
-
private class PartitionIterator[T](
reader: PartitionReader[T],
customMetrics: Map[String, SQLMetric]) extends Iterator[T] {
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index 409e3fa92240..8a65cb623f6e 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -2840,6 +2840,21 @@ class KeyGroupedPartitioningSuite extends
DistributionAndOrderingSuiteBase {
assert(metrics("number of rows read") == "3")
}
+ test("SPARK-55619: Custom metrics of coalesced partitions") {
+ val items_partitions = Array(identity("id"))
+ createTable(items, itemsColumns, items_partitions)
+
+ sql(s"INSERT INTO testcat.ns.$items VALUES " +
+ "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+ "(2, 'bb', 10.0, cast('2021-01-01' as timestamp))")
+
+ val metrics = runAndFetchMetrics {
+ val df = sql(s"SELECT * FROM testcat.ns.$items").coalesce(1)
+ df.collect()
+ }
+ assert(metrics("number of rows read") == "2")
+ }
+
test("SPARK-55411: Fix ArrayIndexOutOfBoundsException when join keys " +
"are less than cluster keys") {
withSQLConf(
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]