This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch branch-4.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.1 by this push:
     new b119334b8439 [SPARK-55619][SQL][4.1] Fix custom metrics in case of 
coalesced partitions
b119334b8439 is described below

commit b119334b8439a21eff32819beb81749c0488f52c
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Sat Feb 21 16:27:03 2026 -0800

    [SPARK-55619][SQL][4.1] Fix custom metrics in case of coalesced partitions
    
    ### What changes were proposed in this pull request?
    
    Replace `PartitionMetricCallback` with a `ConcurrentHashMap` keyed by task 
attempt ID to correctly track reader state across multiple `compute()` calls 
when `DataSourceRDD` is coalesced. The completion listener is registered only 
once per task attempt, and metrics are flushed and carried forward between 
readers as partitions are advanced.
    
    ### Why are the changes needed?
    
    When `DataSourceRDD` is coalesced (e.g., via `.coalesce(1)`), `compute()` 
gets called multiple times per task, which causes the custom metrics to be 
incorrect.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Unit test
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Generated-by: Claude Sonnet 4.6
    
    Closes #54407 from viirya/SPARK-55619-branch-4.1.
    
    Authored-by: Liang-Chi Hsieh <[email protected]>
    Signed-off-by: Liang-Chi Hsieh <[email protected]>
---
 .../execution/datasources/v2/DataSourceRDD.scala   | 95 ++++++++++++----------
 .../connector/KeyGroupedPartitioningSuite.scala    | 15 ++++
 2 files changed, 69 insertions(+), 41 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
index fbf5c06fe051..19a057c72506 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.datasources.v2
 
+import java.util.concurrent.ConcurrentHashMap
+
 import scala.language.existentials
 
 import org.apache.spark._
@@ -24,7 +26,6 @@ import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.connector.metric.CustomTaskMetric
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, 
PartitionReaderFactory}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric}
@@ -34,6 +35,19 @@ import org.apache.spark.util.ArrayImplicits._
 class DataSourceRDDPartition(val index: Int, val inputPartitions: 
Seq[InputPartition])
   extends Partition with Serializable
 
+/**
+ * Holds the state for a reader in a task, used by the completion listener to 
access the most
+ * recently created reader and iterator for final metrics updates and cleanup.
+ *
+ * When `compute()` is called multiple times for the same task (e.g., when 
DataSourceRDD is
+ * coalesced), this state is updated on each call to track the most recent 
reader. The task
+ * completion listener then uses this most recent reader for final cleanup and 
metrics reporting.
+ *
+ * @param reader The partition reader
+ * @param iterator The metrics iterator wrapping the reader
+ */
+private case class ReaderState(reader: PartitionReader[_], iterator: 
MetricsIterator[_])
+
 // TODO: we should have 2 RDDs: an RDD[InternalRow] for row-based scan, an 
`RDD[ColumnarBatch]` for
 // columnar scan.
 class DataSourceRDD(
@@ -44,6 +58,11 @@ class DataSourceRDD(
     customMetrics: Map[String, SQLMetric])
   extends RDD[InternalRow](sc, Nil) {
 
+  // Map from task attempt ID to the most recently created ReaderState for 
that task.
+  // When compute() is called multiple times for the same task (due to 
coalescing), the map entry
+  // is updated each time so the completion listener always closes the last 
reader.
+  @transient private lazy val taskReaderStates = new ConcurrentHashMap[Long, 
ReaderState]()
+
   override protected def getPartitions: Array[Partition] = {
     inputPartitions.zipWithIndex.map {
       case (inputPartitions, index) => new DataSourceRDDPartition(index, 
inputPartitions)
@@ -56,20 +75,34 @@ class DataSourceRDD(
   }
 
   override def compute(split: Partition, context: TaskContext): 
Iterator[InternalRow] = {
+    val taskAttemptId = context.taskAttemptId()
+
+    // Add completion listener only once per task attempt. When compute() is 
called a second time
+    // for the same task (e.g., due to coalescing), the first call will have 
already put a
+    // ReaderState into taskReaderStates, so containsKey returns true and we 
skip this block.
+    if (!taskReaderStates.containsKey(taskAttemptId)) {
+      context.addTaskCompletionListener[Unit] { ctx =>
+        // In case of early stopping before consuming the entire iterator,
+        // we need to do one more metric update at the end of the task.
+        try {
+          val readerState = taskReaderStates.get(ctx.taskAttemptId())
+          if (readerState != null) {
+            CustomMetrics.updateMetrics(
+              readerState.reader.currentMetricsValues.toImmutableArraySeq, 
customMetrics)
+            readerState.iterator.forceUpdateMetrics()
+            readerState.reader.close()
+          }
+        } finally {
+          taskReaderStates.remove(ctx.taskAttemptId())
+        }
+      }
+    }
 
     val iterator = new Iterator[Object] {
       private val inputPartitions = castPartition(split).inputPartitions
       private var currentIter: Option[Iterator[Object]] = None
       private var currentIndex: Int = 0
 
-      private val partitionMetricCallback = new 
PartitionMetricCallback(customMetrics)
-
-      // In case of early stopping before consuming the entire iterator,
-      // we need to do one more metric update at the end of the task.
-      context.addTaskCompletionListener[Unit] { _ =>
-        partitionMetricCallback.execute()
-      }
-
       override def hasNext: Boolean = currentIter.exists(_.hasNext) || 
advanceToNextIter()
 
       override def next(): Object = {
@@ -97,9 +130,18 @@ class DataSourceRDD(
             (iter, rowReader)
           }
 
-          // Once we advance to the next partition, update the metric callback 
for early finish
-          val previousMetrics = partitionMetricCallback.advancePartition(iter, 
reader)
-          previousMetrics.foreach(reader.initMetricsValues)
+          // Flush metrics and close the previous reader before advancing to 
the next one.
+          // Pass the accumulated metrics to the new reader so they carry 
forward correctly.
+          val prevState = taskReaderStates.get(taskAttemptId)
+          if (prevState != null) {
+            val metrics = prevState.reader.currentMetricsValues
+            CustomMetrics.updateMetrics(metrics.toImmutableArraySeq, 
customMetrics)
+            reader.initMetricsValues(metrics)
+            prevState.reader.close()
+          }
+
+          // Update the map so the completion listener always references the 
latest reader.
+          taskReaderStates.put(taskAttemptId, ReaderState(reader, iter))
 
           currentIter = Some(iter)
           hasNext
@@ -115,35 +157,6 @@ class DataSourceRDD(
   }
 }
 
-private class PartitionMetricCallback
-    (customMetrics: Map[String, SQLMetric]) {
-  private var iter: MetricsIterator[_] = null
-  private var reader: PartitionReader[_] = null
-
-  def advancePartition(
-      iter: MetricsIterator[_],
-      reader: PartitionReader[_]): Option[Array[CustomTaskMetric]] = {
-    val metrics = execute()
-
-    this.iter = iter
-    this.reader = reader
-
-    metrics
-  }
-
-  def execute(): Option[Array[CustomTaskMetric]] = {
-    if (iter != null && reader != null) {
-      val metrics = reader.currentMetricsValues
-      CustomMetrics.updateMetrics(metrics.toImmutableArraySeq, customMetrics)
-      iter.forceUpdateMetrics()
-      reader.close()
-      Some(metrics)
-    } else {
-      None
-    }
-  }
-}
-
 private class PartitionIterator[T](
     reader: PartitionReader[T],
     customMetrics: Map[String, SQLMetric]) extends Iterator[T] {
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index 409e3fa92240..8a65cb623f6e 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -2840,6 +2840,21 @@ class KeyGroupedPartitioningSuite extends 
DistributionAndOrderingSuiteBase {
     assert(metrics("number of rows read") == "3")
   }
 
+  test("SPARK-55619: Custom metrics of coalesced partitions") {
+    val items_partitions = Array(identity("id"))
+    createTable(items, itemsColumns, items_partitions)
+
+    sql(s"INSERT INTO testcat.ns.$items VALUES " +
+      "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+      "(2, 'bb', 10.0, cast('2021-01-01' as timestamp))")
+
+    val metrics = runAndFetchMetrics {
+      val df = sql(s"SELECT * FROM testcat.ns.$items").coalesce(1)
+      df.collect()
+    }
+    assert(metrics("number of rows read") == "2")
+  }
+
   test("SPARK-55411: Fix ArrayIndexOutOfBoundsException when join keys " +
     "are less than cluster keys") {
     withSQLConf(


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to