This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
new 435f133a023f [SPARK-55619][SQL][3.5] Fix custom metrics in case of
coalesced partitions
435f133a023f is described below
commit 435f133a023f6646bb00b6300ec01d59932fb275
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Sat Feb 21 10:42:11 2026 -0800
[SPARK-55619][SQL][3.5] Fix custom metrics in case of coalesced partitions
### What changes were proposed in this pull request?
Replace `PartitionMetricCallback` with a `ConcurrentHashMap` keyed by task
attempt ID to correctly track reader state across multiple `compute()` calls
when `DataSourceRDD` is coalesced. The completion listener is registered only
once per task attempt, and metrics are flushed and carried forward between
readers as partitions are advanced.
### Why are the changes needed?
When `DataSourceRDD` is coalesced (e.g., via `.coalesce(1)`), `compute()`
gets called multiple times per task, which causes the custom metrics to be reported incorrectly.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Unit test
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Claude Sonnet 4.6
Closes #54409 from viirya/SPARK-55619-branch-3.5.
Authored-by: Liang-Chi Hsieh <[email protected]>
Signed-off-by: Liang-Chi Hsieh <[email protected]>
---
.../execution/datasources/v2/DataSourceRDD.scala | 94 ++++++++++++----------
.../connector/KeyGroupedPartitioningSuite.scala | 15 ++++
2 files changed, 68 insertions(+), 41 deletions(-)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
index 95b3099403f5..fd668d8070cd 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.datasources.v2
+import java.util.concurrent.ConcurrentHashMap
+
import scala.language.existentials
import org.apache.spark._
@@ -24,7 +26,6 @@ import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.connector.metric.CustomTaskMetric
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader,
PartitionReaderFactory}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric}
@@ -33,6 +34,19 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
class DataSourceRDDPartition(val index: Int, val inputPartitions:
Seq[InputPartition])
extends Partition with Serializable
+/**
+ * Holds the state for a reader in a task, used by the completion listener to
access the most
+ * recently created reader and iterator for final metrics updates and cleanup.
+ *
+ * When `compute()` is called multiple times for the same task (e.g., when
DataSourceRDD is
+ * coalesced), this state is updated on each call to track the most recent
reader. The task
+ * completion listener then uses this most recent reader for final cleanup and
metrics reporting.
+ *
+ * @param reader The partition reader
+ * @param iterator The metrics iterator wrapping the reader
+ */
+private case class ReaderState(reader: PartitionReader[_], iterator:
MetricsIterator[_])
+
// TODO: we should have 2 RDDs: an RDD[InternalRow] for row-based scan, an
`RDD[ColumnarBatch]` for
// columnar scan.
class DataSourceRDD(
@@ -43,6 +57,11 @@ class DataSourceRDD(
customMetrics: Map[String, SQLMetric])
extends RDD[InternalRow](sc, Nil) {
+ // Map from task attempt ID to the most recently created ReaderState for
that task.
+ // When compute() is called multiple times for the same task (due to
coalescing), the map entry
+ // is updated each time so the completion listener always closes the last
reader.
+ @transient private lazy val taskReaderStates = new ConcurrentHashMap[Long,
ReaderState]()
+
override protected def getPartitions: Array[Partition] = {
inputPartitions.zipWithIndex.map {
case (inputPartitions, index) => new DataSourceRDDPartition(index,
inputPartitions)
@@ -55,20 +74,33 @@ class DataSourceRDD(
}
override def compute(split: Partition, context: TaskContext):
Iterator[InternalRow] = {
+ val taskAttemptId = context.taskAttemptId()
+
+ // Add completion listener only once per task attempt. When compute() is
called a second time
+ // for the same task (e.g., due to coalescing), the first call will have
already put a
+ // ReaderState into taskReaderStates, so containsKey returns true and we
skip this block.
+ if (!taskReaderStates.containsKey(taskAttemptId)) {
+ context.addTaskCompletionListener[Unit] { ctx =>
+ // In case of early stopping before consuming the entire iterator,
+ // we need to do one more metric update at the end of the task.
+ try {
+ val readerState = taskReaderStates.get(ctx.taskAttemptId())
+ if (readerState != null) {
+
CustomMetrics.updateMetrics(readerState.reader.currentMetricsValues,
customMetrics)
+ readerState.iterator.forceUpdateMetrics()
+ readerState.reader.close()
+ }
+ } finally {
+ taskReaderStates.remove(ctx.taskAttemptId())
+ }
+ }
+ }
val iterator = new Iterator[Object] {
private val inputPartitions = castPartition(split).inputPartitions
private var currentIter: Option[Iterator[Object]] = None
private var currentIndex: Int = 0
- private val partitionMetricCallback = new
PartitionMetricCallback(customMetrics)
-
- // In case of early stopping before consuming the entire iterator,
- // we need to do one more metric update at the end of the task.
- context.addTaskCompletionListener[Unit] { _ =>
- partitionMetricCallback.execute()
- }
-
override def hasNext: Boolean = currentIter.exists(_.hasNext) ||
advanceToNextIter()
override def next(): Object = {
@@ -96,9 +128,18 @@ class DataSourceRDD(
(iter, rowReader)
}
- // Once we advance to the next partition, update the metric callback
for early finish
- val previousMetrics = partitionMetricCallback.advancePartition(iter,
reader)
- previousMetrics.foreach(reader.initMetricsValues)
+ // Flush metrics and close the previous reader before advancing to
the next one.
+ // Pass the accumulated metrics to the new reader so they carry
forward correctly.
+ val prevState = taskReaderStates.get(taskAttemptId)
+ if (prevState != null) {
+ val metrics = prevState.reader.currentMetricsValues
+ CustomMetrics.updateMetrics(metrics, customMetrics)
+ reader.initMetricsValues(metrics)
+ prevState.reader.close()
+ }
+
+ // Update the map so the completion listener always references the
latest reader.
+ taskReaderStates.put(taskAttemptId, ReaderState(reader, iter))
currentIter = Some(iter)
hasNext
@@ -114,35 +155,6 @@ class DataSourceRDD(
}
}
-private class PartitionMetricCallback
- (customMetrics: Map[String, SQLMetric]) {
- private var iter: MetricsIterator[_] = null
- private var reader: PartitionReader[_] = null
-
- def advancePartition(
- iter: MetricsIterator[_],
- reader: PartitionReader[_]): Option[Array[CustomTaskMetric]] = {
- val metrics = execute()
-
- this.iter = iter
- this.reader = reader
-
- metrics
- }
-
- def execute(): Option[Array[CustomTaskMetric]] = {
- if (iter != null && reader != null) {
- val metrics = reader.currentMetricsValues
- CustomMetrics.updateMetrics(metrics, customMetrics)
- iter.forceUpdateMetrics()
- reader.close()
- Some(metrics)
- } else {
- None
- }
- }
-}
-
private class PartitionIterator[T](
reader: PartitionReader[T],
customMetrics: Map[String, SQLMetric]) extends Iterator[T] {
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index 3bb8ea7542ea..7203cedf3ea7 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -1244,4 +1244,19 @@ class KeyGroupedPartitioningSuite extends
DistributionAndOrderingSuiteBase {
}
assert(metrics("number of rows read") == "3")
}
+
+ test("SPARK-55619: Custom metrics of coalesced partitions") {
+ val items_partitions = Array(identity("id"))
+ createTable(items, items_schema, items_partitions)
+
+ sql(s"INSERT INTO testcat.ns.$items VALUES " +
+ "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+ "(2, 'bb', 10.0, cast('2021-01-01' as timestamp))")
+
+ val metrics = runAndFetchMetrics {
+ val df = sql(s"SELECT * FROM testcat.ns.$items").coalesce(1)
+ df.collect()
+ }
+ assert(metrics("number of rows read") == "2")
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]