This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
new 115a29c5d130 [SPARK-55302][SQL][3.5] Fix custom metrics in case of
`KeyGroupedPartitioning`
115a29c5d130 is described below
commit 115a29c5d13077924ef3989ff3d684554c25d38c
Author: Peter Toth <[email protected]>
AuthorDate: Fri Feb 20 14:04:05 2026 -0800
[SPARK-55302][SQL][3.5] Fix custom metrics in case of
`KeyGroupedPartitioning`
### What changes were proposed in this pull request?
This PR adds a new `initMetricsValues()` method to `PartitionReader` so as
to initialize custom metrics returned by `currentMetricsValues()`. In case of
`KeyGroupedPartitioning` multiple input partitions are grouped and so multiple
`PartitionReader`s belong to one output partition. A `PartitionReader` needs to
be initialized with the metrics accumulated by the previous `PartitionReader` of the
same partition group so that it reports the correct cumulative value.
### Why are the changes needed?
To calculate custom metrics correctly.
### Does this PR introduce _any_ user-facing change?
It fixes the metrics calculation.
### How was this patch tested?
New UT is added.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #54402 from peter-toth/SPARK-55302-fix-kgp-custom-metrics-3.5.
Authored-by: Peter Toth <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
.../spark/sql/connector/read/PartitionReader.java | 9 ++++++
.../sql/connector/catalog/InMemoryBaseTable.scala | 32 ++++++++++++++++++++--
.../execution/datasources/v2/DataSourceRDD.scala | 20 ++++++++++----
.../connector/KeyGroupedPartitioningSuite.scala | 18 ++++++++++++
.../datasources/InMemoryTableMetricSuite.scala | 18 ++----------
.../apache/spark/sql/test/SharedSparkSession.scala | 23 ++++++++++++++++
6 files changed, 97 insertions(+), 23 deletions(-)
diff --git
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java
index 5286bbf9f85a..c12bc14a49c4 100644
---
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java
+++
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java
@@ -58,4 +58,13 @@ public interface PartitionReader<T> extends Closeable {
CustomTaskMetric[] NO_METRICS = {};
return NO_METRICS;
}
+
+ /**
+ * Sets the initial value of metrics before fetching any data from the
reader. This is called
+ * when multiple {@link PartitionReader}s are grouped into one partition in
case of
+ * {@link
org.apache.spark.sql.connector.read.partitioning.KeyGroupedPartitioning} and
the reader
+ * is initialized with the metrics returned by the previous reader that
belongs to the same
+ * partition. By default, this method does nothing.
+ */
+ default void initMetricsValues(CustomTaskMetric[] metrics) {}
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
index a309db341d8e..d9641cd7f72d 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
@@ -31,7 +31,7 @@ import
org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow,
import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils}
import org.apache.spark.sql.connector.distributions.{Distribution,
Distributions}
import org.apache.spark.sql.connector.expressions._
-import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
+import org.apache.spark.sql.connector.metric.{CustomMetric, CustomSumMetric,
CustomTaskMetric}
import org.apache.spark.sql.connector.read._
import org.apache.spark.sql.connector.read.colstats.{ColumnStatistics,
Histogram, HistogramBin}
import
org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning,
Partitioning, UnknownPartitioning}
@@ -430,6 +430,10 @@ abstract class InMemoryBaseTable(
}
new BufferedRowsReaderFactory(metadataColumns.toSeq, nonMetadataColumns,
tableSchema)
}
+
+ override def supportedCustomMetrics(): Array[CustomMetric] = {
+ Array(new RowsReadCustomMetric)
+ }
}
case class InMemoryBatchScan(
@@ -662,10 +666,13 @@ private class BufferedRowsReader(
}
private var index: Int = -1
+ private var rowsRead: Long = 0
override def next(): Boolean = {
index += 1
- index < partition.rows.length
+ val hasNext = index < partition.rows.length
+ if (hasNext) rowsRead += 1
+ hasNext
}
override def get(): InternalRow = {
@@ -701,6 +708,22 @@ private class BufferedRowsReader(
row.get(index, dt)
}
}
+
+ override def initMetricsValues(metrics: Array[CustomTaskMetric]): Unit = {
+ metrics.foreach { m =>
+ m.name match {
+ case "rows_read" => rowsRead = m.value()
+ }
+ }
+ }
+
+ override def currentMetricsValues(): Array[CustomTaskMetric] = {
+ val metric = new CustomTaskMetric {
+ override def name(): String = "rows_read"
+ override def value(): Long = rowsRead
+ }
+ Array(metric)
+ }
}
private object BufferedRowsWriterFactory extends DataWriterFactory with
StreamingDataWriterFactory {
@@ -744,3 +767,8 @@ class InMemorySimpleCustomMetric extends CustomMetric {
s"in-memory rows: ${taskMetrics.sum}"
}
}
+
+class RowsReadCustomMetric extends CustomSumMetric {
+ override def name(): String = "rows_read"
+ override def description(): String = "number of rows read"
+}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
index 36872f232e7e..95b3099403f5 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
@@ -24,6 +24,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.metric.CustomTaskMetric
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader,
PartitionReaderFactory}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric}
@@ -96,7 +97,8 @@ class DataSourceRDD(
}
// Once we advance to the next partition, update the metric callback
for early finish
- partitionMetricCallback.advancePartition(iter, reader)
+ val previousMetrics = partitionMetricCallback.advancePartition(iter,
reader)
+ previousMetrics.foreach(reader.initMetricsValues)
currentIter = Some(iter)
hasNext
@@ -117,18 +119,26 @@ private class PartitionMetricCallback
private var iter: MetricsIterator[_] = null
private var reader: PartitionReader[_] = null
- def advancePartition(iter: MetricsIterator[_], reader: PartitionReader[_]):
Unit = {
- execute()
+ def advancePartition(
+ iter: MetricsIterator[_],
+ reader: PartitionReader[_]): Option[Array[CustomTaskMetric]] = {
+ val metrics = execute()
this.iter = iter
this.reader = reader
+
+ metrics
}
- def execute(): Unit = {
+ def execute(): Option[Array[CustomTaskMetric]] = {
if (iter != null && reader != null) {
- CustomMetrics.updateMetrics(reader.currentMetricsValues, customMetrics)
+ val metrics = reader.currentMetricsValues
+ CustomMetrics.updateMetrics(metrics, customMetrics)
iter.forceUpdateMetrics()
reader.close()
+ Some(metrics)
+ } else {
+ None
}
}
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index 6f40775ce242..3bb8ea7542ea 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -1226,4 +1226,22 @@ class KeyGroupedPartitioningSuite extends
DistributionAndOrderingSuiteBase {
}
}
}
+
+ test("SPARK-55302: Custom metrics of grouped partitions") {
+ val items_partitions = Array(identity("id"))
+ createTable(items, items_schema, items_partitions)
+
+ sql(s"INSERT INTO testcat.ns.$items VALUES " +
+ "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+ "(4, 'bb', 10.0, cast('2021-01-01' as timestamp)), " +
+ "(4, 'cc', 15.5, cast('2021-02-01' as timestamp))")
+
+ val metrics = runAndFetchMetrics {
+ val df = sql(s"SELECT * FROM testcat.ns.$items")
+ val scans = collectScans(df.queryExecution.executedPlan)
+ assert(scans(0).inputRDD.partitions.length === 2, "items scan should
have 2 partition groups")
+ df.collect()
+ }
+ assert(metrics("number of rows read") == "3")
+ }
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala
index 33e2fc46ccba..8798d678478e 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.datasources
import java.util.Collections
import org.scalatest.BeforeAndAfter
-import org.scalatest.time.SpanSugar._
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.connector.catalog.{Identifier,
InMemoryTableCatalog}
@@ -42,7 +41,7 @@ class InMemoryTableMetricSuite
spark.sessionState.conf.clear()
}
- private def testMetricOnDSv2(func: String => Unit, checker: Map[Long,
String] => Unit): Unit = {
+ private def testMetricOnDSv2(func: String => Unit, checker: Map[String,
String] => Unit): Unit = {
withTable("testcat.table_name") {
val statusStore = spark.sharedState.statusStore
val oldCount = statusStore.executionsList().size
@@ -54,21 +53,8 @@ class InMemoryTableMetricSuite
new StructType().add("i", "int"),
Array.empty[Transform], Collections.emptyMap[String, String])
- func("testcat.table_name")
+ val metrics = runAndFetchMetrics(func("testcat.table_name"))
- // Wait until the new execution is started and being tracked.
- eventually(timeout(10.seconds), interval(10.milliseconds)) {
- assert(statusStore.executionsCount() >= oldCount)
- }
-
- // Wait for listener to finish computing the metrics for the execution.
- eventually(timeout(10.seconds), interval(10.milliseconds)) {
- assert(statusStore.executionsList().nonEmpty &&
- statusStore.executionsList().last.metricValues != null)
- }
-
- val execId = statusStore.executionsList().last.executionId
- val metrics = statusStore.executionMetrics(execId)
checker(metrics)
}
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
index ed2e309fa075..c23bf4204f75 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
@@ -53,6 +53,29 @@ trait SharedSparkSession extends SQLTestUtils with
SharedSparkSessionBase {
doThreadPostAudit()
}
}
+
+ def runAndFetchMetrics(func: => Unit): Map[String, String] = {
+ val statusStore = spark.sharedState.statusStore
+ val oldCount = statusStore.executionsList().size
+
+ func
+
+ // Wait until the new execution is started and being tracked.
+ eventually(timeout(10.seconds), interval(10.milliseconds)) {
+ assert(statusStore.executionsCount() >= oldCount)
+ }
+
+ // Wait for listener to finish computing the metrics for the execution.
+ eventually(timeout(10.seconds), interval(10.milliseconds)) {
+ assert(statusStore.executionsList().nonEmpty &&
+ statusStore.executionsList().last.metricValues != null)
+ }
+
+ val exec = statusStore.executionsList().last
+ val execId = exec.executionId
+ val sqlMetrics = exec.metrics.map { metric => metric.accumulatorId ->
metric.name }.toMap
+ statusStore.executionMetrics(execId).map { case (k, v) => sqlMetrics(k) ->
v }
+ }
}
/**
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]