This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 84b9848f9a69 [SPARK-51097][SS] Re-introduce RocksDB state store's last 
uploaded snapshot version instance metrics
84b9848f9a69 is described below

commit 84b9848f9a69c32faa9baf623a90bbb2f4939d04
Author: Zeyu Chen <zyc...@gmail.com>
AuthorDate: Sat Mar 15 07:02:26 2025 +0900

    [SPARK-51097][SS] Re-introduce RocksDB state store's last uploaded snapshot 
version instance metrics
    
    ### What changes were proposed in this pull request?
    
    SPARK-51097
    
    #50161 recently had to revert the changes in #49816 due to instance metrics 
showing up on SparkUI, causing excessive clutter. This PR aims to re-introduce 
the new instance metrics while incorporating the fix in #50157.
    
    The main differences between this and the original PR are:
    - Distinguishing regular metrics from instance metrics as a whole, so that
      SparkPlan does not accidentally pick up hundreds of instance metrics.
      - Because of this change, some of the logic behind processing instance
        metrics was simplified, so I cleaned that up. More detail in the section below.
    - Adding a test to the suite to verify that no nodes in the execution plan
      contain the newly introduced instance metrics. Since SparkPlan nodes store
      metrics as strings, distinguishing between metric types isn't possible in
      this context, so the test only checks for the specific metric introduced
      in this PR.
    - Small edit to another test to verify instance metric behavior as well 
(see 
[here](https://github.com/apache/spark/pull/50195/files#diff-c7a07d8e111bfd75e5579874c6be9ed5afcbfae7411d663352dbbbb51427b084R127))
    
    Line-by-line:
    - [statefulOperators.scala line 206-238](https://github.com/apache/spark/pull/50195/files#diff-da6ad0bc819dce994a16436fa0797bfc8484644b475227a04c2a3eb5927515f7R206-R238)
      is one of the main changes. Instead of adding instance metrics to all metrics,
      we lazily initialize the instance metrics alongside them but keep the two
      SQLMetric maps separate. The instance metrics map is also now keyed by the
      metric objects themselves, since we want to hold the configuration information
      as well and [...]
    - [statefulOperators.scala line 349-379](https://github.com/apache/spark/pull/50195/files#diff-da6ad0bc819dce994a16436fa0797bfc8484644b475227a04c2a3eb5927515f7R349-R379)
      to use the new instance metric map instead of switching between string metric
      names and metric objects (see the simplified sketch after this list).
    - [statefulOperators.scala line 474-486](https://github.com/apache/spark/pull/50195/files#diff-da6ad0bc819dce994a16436fa0797bfc8484644b475227a04c2a3eb5927515f7R474-R486)
      to initialize all metrics in `stateStoreInstanceMetrics` and remove the
      `stateStoreInstanceMetricObjects` method, as we no longer need separate map
      indexing for instance metrics.
    - [RocksDBStateStoreIntegrationSuite line 
523-561](https://github.com/apache/spark/pull/50195/files#diff-c7a07d8e111bfd75e5579874c6be9ed5afcbfae7411d663352dbbbb51427b084R523-R561)
 to verify instance metrics don't appear in SparkPlan
    - Small edit to a pre-existing test in RocksDBStateStoreIntegrationSuite line 127
      to instead verify that instance metrics still show up in progress reports.
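
    To illustrate the reporting behavior, here is a simplified sketch (not the exact
    code in the diff; it assumes `instanceMetrics: Map[StateStoreInstanceMetric, SQLMetric]`
    and the new report-limit config value are in scope) of how `getProgress()` trims
    instance metrics before merging them into the progress report's custom metrics:

    ```scala
    // Simplified sketch of the instance metric selection in getProgress().
    val instanceMetricsToReport: Map[String, Long] = instanceMetrics
      // Drop metrics still at their initial value if they are marked to be ignored.
      .filter { case (metricConf, sqlMetric) =>
        !metricConf.ignoreIfUnchanged || !sqlMetric.isZero
      }
      // Group per metric type, ignoring partition id and store name.
      .groupBy { case (metricConf, _) => metricConf.metricPrefix }
      .flatMap { case (_, group) =>
        val ordering = group.head._1.ordering
        group.toSeq
          .map { case (metricConf, sqlMetric) =>
            // Unset SQLMetrics read as zero, so report the metric's real initial value.
            metricConf.name -> (if (sqlMetric.isZero) metricConf.initValue else sqlMetric.value)
          }
          .sortBy(_._2)(ordering)
          // Keep only the first N metrics, as set by
          // spark.sql.streaming.stateStore.numStateStoreInstanceMetricsToReport.
          .take(numStateStoreInstanceMetricsToReport)
      }
    ```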
    
    Before the fix (note that metrics are sorted lexicographically):
    
![image](https://github.com/user-attachments/assets/f6ed3e44-7a38-403f-bb14-8e30ad835498)
    
    After including the fix:
    
![image](https://github.com/user-attachments/assets/12e16b39-094b-4e29-9fe5-15a1b4c4fac2)
    
    ### Why are the changes needed?
    
    From #49816:
    There is currently a lack of observability into state store-specific maintenance
    information, notably the last uploaded snapshot version. This affects the ability
    to identify performance degradation behind maintenance tasks, among other issues
    described in [SPARK-51097](https://issues.apache.org/jira/browse/SPARK-51097).
    
    ### Does this PR introduce _any_ user-facing change?
    
    From #49816:
    There will be some new metrics displayed in StreamingQueryProgress:
    ```
    Streaming query made progress: {
      ...
      "stateOperators" : [ {
        ...
        "customMetrics" : {
          ...
          "SnapshotLastUploaded.partition_0_default" : 2,
          "SnapshotLastUploaded.partition_12_default" : 10,
          "SnapshotLastUploaded.partition_8_default" : 10,
          ...
        }
      } ],
      "sources" : ...,
      "sink" : ...
    }
    ```
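
    As a usage sketch (assuming a running streaming `query`; this mirrors what the new
    tests in RocksDBStateStoreIntegrationSuite do), the per-partition values can be
    read back from the last progress report:

    ```scala
    import scala.jdk.CollectionConverters.MapHasAsScala

    // Pull the snapshot-lag instance metrics out of the operator's custom metrics.
    val snapshotLagMetrics = query.lastProgress
      .stateOperators(0)
      .customMetrics
      .asScala
      .filter { case (name, _) => name.startsWith("SnapshotLastUploaded.") }

    // A value of -1 means that state store instance has never uploaded a snapshot.
    snapshotLagMetrics.foreach { case (name, version) => println(s"$name -> $version") }
    ```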
    
    ### How was this patch tested?
    
    Five tests are added to RocksDBStateStoreIntegrationSuite:
    - The first four are identical to the tests from #49816 and verify that the metrics
      update properly, using custom faulty stores and different types of stateful
      queries (deduplication and join).
    - The fifth test goes through the generated execution plan for stateful queries to
      make sure no instance metric shows up.
    
    I additionally verified these metrics manually on the Spark UI (see the screenshots
    above, with and without the fix).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #50195 from zecookiez/SPARK-51097-rocksdb-with-fix.
    
    Lead-authored-by: Zeyu Chen <zyc...@gmail.com>
    Co-authored-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../org/apache/spark/sql/internal/SQLConf.scala    |  16 ++
 .../streaming/StreamingSymmetricHashJoinExec.scala |   7 +-
 .../sql/execution/streaming/state/RocksDB.scala    |  12 +-
 .../state/RocksDBStateStoreProvider.scala          |  17 +-
 .../sql/execution/streaming/state/StateStore.scala | 101 ++++++-
 .../state/SymmetricHashJoinStateManager.scala      |   4 +-
 .../execution/streaming/statefulOperators.scala    | 123 +++++++--
 .../RocksDBStateStoreCheckpointFormatV2Suite.scala |   3 +
 .../state/RocksDBStateStoreIntegrationSuite.scala  | 295 ++++++++++++++++++++-
 9 files changed, 544 insertions(+), 34 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index ef213f115aa3..7a2bd12c868a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2233,6 +2233,19 @@ object SQLConf {
       .intConf
       .createWithDefault(10)
 
+  val STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT =
+    
buildConf("spark.sql.streaming.stateStore.numStateStoreInstanceMetricsToReport")
+      .internal()
+      .doc(
+        "Number of state store instance metrics included in streaming query 
progress messages " +
+        "per stateful operator. Instance metrics are selected based on 
metric-specific ordering " +
+        "to minimize noise in the progress report."
+      )
+      .version("4.1.0")
+      .intConf
+      .checkValue(k => k >= 0, "Must be greater than or equal to 0")
+      .createWithDefault(5)
+
   val STATE_STORE_FORMAT_VALIDATION_ENABLED =
     buildConf("spark.sql.streaming.stateStore.formatValidation.enabled")
       .internal()
@@ -5775,6 +5788,9 @@ class SQLConf extends Serializable with Logging with 
SqlApiConf {
 
   def numStateStoreMaintenanceThreads: Int = 
getConf(NUM_STATE_STORE_MAINTENANCE_THREADS)
 
+  def numStateStoreInstanceMetricsToReport: Int =
+    getConf(STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT)
+
   def stateStoreMaintenanceShutdownTimeout: Long = 
getConf(STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT)
 
   def stateStoreMinDeltasForSnapshot: Int = 
getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
index 5eab57f7372c..7c8ba260b88a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
@@ -224,7 +224,7 @@ case class StreamingSymmetricHashJoinExec(
 
   override def shortName: String = "symmetricHashJoin"
 
-  private val stateStoreNames =
+  override val stateStoreNames: Seq[String] =
     SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide)
 
   override def operatorStateMetadata(
@@ -527,9 +527,8 @@ case class StreamingSymmetricHashJoinExec(
           (leftSideJoiner.numUpdatedStateRows + 
rightSideJoiner.numUpdatedStateRows)
         numTotalStateRows += combinedMetrics.numKeys
         stateMemory += combinedMetrics.memoryUsedBytes
-        combinedMetrics.customMetrics.foreach { case (metric, value) =>
-          longMetric(metric.name) += value
-        }
+        setStoreCustomMetrics(combinedMetrics.customMetrics)
+        setStoreInstanceMetrics(combinedMetrics.instanceMetrics)
       }
 
       val stateStoreNames = 
SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide);
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
index 5ff6ae6551ff..aa8e96bac046 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
@@ -22,7 +22,7 @@ import java.util.Locale
 import java.util.Set
 import java.util.UUID
 import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue, 
TimeUnit}
-import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger}
+import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger, AtomicLong}
 import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.{mutable, Map}
@@ -147,6 +147,10 @@ class RocksDB(
   private val enableChangelogCheckpointing: Boolean = 
conf.enableChangelogCheckpointing
   @volatile protected var loadedVersion: Long = -1L   // -1 = nothing valid is 
loaded
 
+  // Can be updated by whichever thread uploaded a snapshot, which could be 
either task,
+  // maintenance, or both. -1 represents no version has ever been uploaded.
+  protected val lastUploadedSnapshotVersion: AtomicLong = new AtomicLong(-1L)
+
   // variables to manage checkpoint ID. Once a checkpointing finishes, it 
needs to return
   // `lastCommittedStateStoreCkptId` as the committed checkpointID, as well as
   // `lastCommitBasedStateStoreCkptId` as the checkpontID of the previous 
version that is based on.
@@ -1297,6 +1301,7 @@ class RocksDB(
       bytesCopied = fileManagerMetrics.bytesCopied,
       filesCopied = fileManagerMetrics.filesCopied,
       filesReused = fileManagerMetrics.filesReused,
+      lastUploadedSnapshotVersion = lastUploadedSnapshotVersion.get(),
       zipFileBytesUncompressed = fileManagerMetrics.zipFileBytesUncompressed,
       nativeOpsMetrics = nativeOpsMetrics)
   }
@@ -1465,6 +1470,8 @@ class RocksDB(
         log"with uniqueId: ${MDC(LogKeys.UUID, snapshot.uniqueId)} " +
         log"time taken: ${MDC(LogKeys.TIME_UNITS, uploadTime)} ms. " +
         log"Current lineage: ${MDC(LogKeys.LINEAGE, lineageManager)}")
+      // Compare and update with the version that was just uploaded.
+      lastUploadedSnapshotVersion.updateAndGet(v => Math.max(snapshot.version, 
v))
     } finally {
       snapshot.close()
     }
@@ -1916,7 +1923,8 @@ case class RocksDBMetrics(
     bytesCopied: Long,
     filesReused: Long,
     zipFileBytesUncompressed: Option[Long],
-    nativeOpsMetrics: Map[String, Long]) {
+    nativeOpsMetrics: Map[String, Long],
+    lastUploadedSnapshotVersion: Long) {
   def json: String = Serialization.write(this)(RocksDBMetrics.format)
 }
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
index cd9fdb9469d6..ee11ae6cfae8 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -316,14 +316,21 @@ private[sql] class RocksDBStateStoreProvider
         ) ++ rocksDBMetrics.zipFileBytesUncompressed.map(bytes =>
           Map(CUSTOM_METRIC_ZIP_FILE_BYTES_UNCOMPRESSED -> 
bytes)).getOrElse(Map())
 
+        val stateStoreInstanceMetrics = Map[StateStoreInstanceMetric, Long](
+          CUSTOM_INSTANCE_METRIC_SNAPSHOT_LAST_UPLOADED
+            .withNewId(id.partitionId, id.storeName) -> 
rocksDBMetrics.lastUploadedSnapshotVersion
+        )
+
         StateStoreMetrics(
           rocksDBMetrics.numUncommittedKeys,
           rocksDBMetrics.totalMemUsageBytes,
-          stateStoreCustomMetrics)
+          stateStoreCustomMetrics,
+          stateStoreInstanceMetrics
+        )
       } else {
         logInfo(log"Failed to collect metrics for 
store_id=${MDC(STATE_STORE_ID, id)} " +
           log"and version=${MDC(VERSION_NUM, version)}")
-        StateStoreMetrics(0, 0, Map.empty)
+        StateStoreMetrics(0, 0, Map.empty, Map.empty)
       }
     }
 
@@ -497,6 +504,8 @@ private[sql] class RocksDBStateStoreProvider
 
   override def supportedCustomMetrics: Seq[StateStoreCustomMetric] = 
ALL_CUSTOM_METRICS
 
+  override def supportedInstanceMetrics: Seq[StateStoreInstanceMetric] = 
ALL_INSTANCE_METRICS
+
   private[state] def latestVersion: Long = rocksDB.getLatestVersion()
 
   /** Internal fields and methods */
@@ -888,6 +897,10 @@ object RocksDBStateStoreProvider {
     CUSTOM_METRIC_COMPACT_WRITTEN_BYTES, CUSTOM_METRIC_FLUSH_WRITTEN_BYTES,
     CUSTOM_METRIC_PINNED_BLOCKS_MEM_USAGE, 
CUSTOM_METRIC_NUM_INTERNAL_COL_FAMILIES_KEYS,
     CUSTOM_METRIC_NUM_EXTERNAL_COL_FAMILIES, 
CUSTOM_METRIC_NUM_INTERNAL_COL_FAMILIES)
+
+  val CUSTOM_INSTANCE_METRIC_SNAPSHOT_LAST_UPLOADED = 
StateStoreSnapshotLastUploadInstanceMetric()
+
+  val ALL_INSTANCE_METRICS = Seq(CUSTOM_INSTANCE_METRIC_SNAPSHOT_LAST_UPLOADED)
 }
 
 /** [[StateStoreChangeDataReader]] implementation for 
[[RocksDBStateStoreProvider]] */
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index 928e2b0e9b99..33a21c79f3db 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -249,12 +249,17 @@ class WrappedReadStateStore(store: StateStore) extends 
ReadStateStore {
  * @param memoryUsedBytes Memory used by the state store
  * @param customMetrics   Custom implementation-specific metrics
  *                        The metrics reported through this must have the same 
`name` as those
- *                        reported by `StateStoreProvider.customMetrics`.
+ *                        reported by 
`StateStoreProvider.supportedCustomMetrics`.
+ * @param instanceMetrics Custom implementation-specific metrics that are 
specific to state stores
+ *                        The metrics reported through this must have the same 
`name` as those
+ *                        reported by 
`StateStoreProvider.supportedInstanceMetrics`,
+ *                        including partition id and store name.
  */
 case class StateStoreMetrics(
     numKeys: Long,
     memoryUsedBytes: Long,
-    customMetrics: Map[StateStoreCustomMetric, Long])
+    customMetrics: Map[StateStoreCustomMetric, Long],
+    instanceMetrics: Map[StateStoreInstanceMetric, Long] = Map.empty)
 
 /**
  * State store checkpoint information, used to pass checkpointing information 
from executors
@@ -284,7 +289,8 @@ object StateStoreMetrics {
     StateStoreMetrics(
       allMetrics.map(_.numKeys).sum,
       allMetrics.map(_.memoryUsedBytes).sum,
-      combinedCustomMetrics)
+      combinedCustomMetrics,
+      allMetrics.flatMap(_.instanceMetrics).toMap)
   }
 }
 
@@ -321,6 +327,86 @@ case class StateStoreCustomTimingMetric(name: String, 
desc: String) extends Stat
     SQLMetrics.createTimingMetric(sparkContext, desc)
 }
 
+trait StateStoreInstanceMetric {
+  def metricPrefix: String
+  def descPrefix: String
+  def partitionId: Option[Int]
+  def storeName: String
+  def initValue: Long
+
+  def createSQLMetric(sparkContext: SparkContext): SQLMetric
+
+  /**
+   * Defines how instance metrics are selected for progress reporting.
+   * Metrics are sorted by value using this ordering, and only the first N 
metrics are displayed.
+   * For example, the highest N metrics by value should use 
Ordering.Long.reverse.
+   */
+  def ordering: Ordering[Long]
+
+  /** Should this instance metric be reported if it is unchanged from its 
initial value */
+  def ignoreIfUnchanged: Boolean
+
+  /**
+   * Defines how to merge metric values from different executors for the same 
state store
+   * instance in situations like speculative execution or provider unloading. 
In most cases,
+   * the original metric value is at its initial value.
+   */
+  def combine(originalMetric: SQLMetric, value: Long): Long
+
+  def name: String = {
+    assert(partitionId.isDefined, "Partition ID must be defined for instance 
metric name")
+    s"$metricPrefix.partition_${partitionId.get}_$storeName"
+  }
+
+  def desc: String = {
+    assert(partitionId.isDefined, "Partition ID must be defined for instance 
metric description")
+    s"$descPrefix (partitionId = ${partitionId.get}, storeName = $storeName)"
+  }
+
+  def withNewId(partitionId: Int, storeName: String): StateStoreInstanceMetric
+}
+
+case class StateStoreSnapshotLastUploadInstanceMetric(
+    partitionId: Option[Int] = None,
+    storeName: String = StateStoreId.DEFAULT_STORE_NAME)
+  extends StateStoreInstanceMetric {
+
+  override def metricPrefix: String = "SnapshotLastUploaded"
+
+  override def descPrefix: String = {
+    "The last uploaded version of the snapshot for a specific state store 
instance"
+  }
+
+  override def initValue: Long = -1L
+
+  override def createSQLMetric(sparkContext: SparkContext): SQLMetric = {
+    SQLMetrics.createSizeMetric(sparkContext, desc, initValue)
+  }
+
+  override def ordering: Ordering[Long] = Ordering.Long
+
+  override def ignoreIfUnchanged: Boolean = false
+
+  override def combine(originalMetric: SQLMetric, value: Long): Long = {
+    // Check for cases where the initial value is less than 0, forcing 
metric.value to
+    // convert it to 0. Since the last uploaded snapshot version can have an 
initial
+    // value of -1, we need special handling to avoid turning the -1 into a 0.
+    if (originalMetric.isZero) {
+      value
+    } else {
+      // Use max to grab the most recent snapshot version across all executors
+      // of the same store instance
+      Math.max(originalMetric.value, value)
+    }
+  }
+
+  override def withNewId(
+      partitionId: Int,
+      storeName: String): StateStoreSnapshotLastUploadInstanceMetric = {
+    copy(partitionId = Some(partitionId), storeName = storeName)
+  }
+}
+
 sealed trait KeyStateEncoderSpec {
   def keySchema: StructType
   def jsonValue: JValue
@@ -495,9 +581,16 @@ trait StateStoreProvider {
   /**
    * Optional custom metrics that the implementation may want to report.
    * @note The StateStore objects created by this provider must report the 
same custom metrics
-   * (specifically, same names) through `StateStore.metrics`.
+   * (specifically, same names) through `StateStore.metrics.customMetrics`.
    */
   def supportedCustomMetrics: Seq[StateStoreCustomMetric] = Nil
+
+  /**
+   * Optional custom state store instance metrics that the implementation may 
want to report.
+   * @note The StateStore objects created by this provider must report the 
same instance metrics
+   * (specifically, same names) through `StateStore.metrics.instanceMetrics`.
+   */
+  def supportedInstanceMetrics: Seq[StateStoreInstanceMetric] = Seq.empty
 }
 
 object StateStoreProvider {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
index f487ddf4252c..66ab0006c498 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
@@ -447,7 +447,9 @@ class SymmetricHashJoinStateManager(
       keyToNumValuesMetrics.memoryUsedBytes + 
keyWithIndexToValueMetrics.memoryUsedBytes,
       keyWithIndexToValueMetrics.customMetrics.map {
         case (metric, value) => (metric.withNewDesc(desc = 
newDesc(metric.desc)), value)
-      }
+      },
+      // We want to collect instance metrics from both state stores
+      keyWithIndexToValueMetrics.instanceMetrics ++ 
keyToNumValuesMetrics.instanceMetrics
     )
   }
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index fc269897edd6..64bbb998ca59 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -203,20 +203,46 @@ trait StateStoreWriter
 
   def operatorStateMetadataVersion: Int = 1
 
-  override lazy val metrics = statefulOperatorCustomMetrics ++ Map(
-    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output 
rows"),
-    "numRowsDroppedByWatermark" -> SQLMetrics.createMetric(sparkContext,
-      "number of rows which are dropped by watermark"),
-    "numTotalStateRows" -> SQLMetrics.createMetric(sparkContext, "number of 
total state rows"),
-    "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of 
updated state rows"),
-    "allUpdatesTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "time to 
update"),
-    "numRemovedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of 
removed state rows"),
-    "allRemovalsTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "time 
to remove"),
-    "commitTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "time to 
commit changes"),
-    "stateMemory" -> SQLMetrics.createSizeMetric(sparkContext, "memory used by 
state"),
-    "numStateStoreInstances" -> SQLMetrics.createMetric(sparkContext,
-      "number of state store instances")
-  ) ++ stateStoreCustomMetrics ++ pythonMetrics
+  override lazy val metrics = {
+    // Lazy initialize instance metrics, but do not include these with regular 
metrics
+    instanceMetrics
+    statefulOperatorCustomMetrics ++ Map(
+      "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of 
output rows"),
+      "numRowsDroppedByWatermark" -> SQLMetrics
+        .createMetric(sparkContext, "number of rows which are dropped by 
watermark"),
+      "numTotalStateRows" -> SQLMetrics.createMetric(sparkContext, "number of 
total state rows"),
+      "numUpdatedStateRows" -> SQLMetrics
+        .createMetric(sparkContext, "number of updated state rows"),
+      "allUpdatesTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "time 
to update"),
+      "numRemovedStateRows" -> SQLMetrics
+        .createMetric(sparkContext, "number of removed state rows"),
+      "allRemovalsTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "time 
to remove"),
+      "commitTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "time to 
commit changes"),
+      "stateMemory" -> SQLMetrics.createSizeMetric(sparkContext, "memory used 
by state"),
+      "numStateStoreInstances" -> SQLMetrics
+        .createMetric(sparkContext, "number of state store instances")
+    ) ++ stateStoreCustomMetrics ++ pythonMetrics
+  }
+
+  /**
+   * Map of all instance metrics (including partition ID and store names) to
+   * their SQLMetric counterpart.
+   *
+   * The instance metric objects hold additional information on how to report 
these metrics,
+   * while the SQLMetric objects store the metric values.
+   *
+   * This map is similar to the metrics map, but needs to be kept separate to 
prevent propagating
+   * all initialized instance metrics to SparkUI.
+   */
+  lazy val instanceMetrics: Map[StateStoreInstanceMetric, SQLMetric] =
+    stateStoreInstanceMetrics
+
+  override def resetMetrics(): Unit = {
+    super.resetMetrics()
+    instanceMetrics.valuesIterator.foreach(_.reset())
+  }
+
+  val stateStoreNames: Seq[String] = Seq(StateStoreId.DEFAULT_STORE_NAME)
 
   // This method is only used to fetch the state schema directory path for
   // operators that use StateSchemaV3, as prior versions only use a single
@@ -320,11 +346,43 @@ trait StateStoreWriter
    * the driver after this SparkPlan has been executed and metrics have been 
updated.
    */
   def getProgress(): StateOperatorProgress = {
+    val instanceMetricsToReport = instanceMetrics
+      .filter {
+        case (metricConf, sqlMetric) =>
+          // Keep instance metrics that are updated or aren't marked to be 
ignored,
+          // as their initial value could still be important.
+          !metricConf.ignoreIfUnchanged || !sqlMetric.isZero
+      }
+      .groupBy {
+        // Group all instance metrics underneath their common metric prefix
+        // to ignore partition and store names.
+        case (metricConf, sqlMetric) => metricConf.metricPrefix
+      }
+      .flatMap {
+        case (_, metrics) =>
+          // Select at most N metrics based on the metric's defined ordering
+          // to report to the driver. For example, ascending order would be 
taking the N smallest.
+          val metricConf = metrics.head._1
+          metrics
+            .map {
+              case (metricConf, sqlMetric) =>
+                // Use metric name as it will be combined with custom metrics 
in progress reports.
+                // All metrics that are at their initial value at this stage 
should not be ignored
+                // and should show their real initial value.
+                metricConf.name -> (if (sqlMetric.isZero) metricConf.initValue
+                                    else sqlMetric.value)
+            }
+            .toSeq
+            .sortBy(_._2)(metricConf.ordering)
+            .take(conf.numStateStoreInstanceMetricsToReport)
+            .toMap
+      }
     val customMetrics = (stateStoreCustomMetrics ++ 
statefulOperatorCustomMetrics)
       .map(entry => entry._1 -> longMetric(entry._1).value)
+    val allCustomMetrics = customMetrics ++ instanceMetricsToReport
 
     val javaConvertedCustomMetrics: java.util.HashMap[String, java.lang.Long] =
-      new java.util.HashMap(customMetrics.transform((_, v) => 
long2Long(v)).asJava)
+      new java.util.HashMap(allCustomMetrics.transform((_, v) => 
long2Long(v)).asJava)
 
     // We now don't report number of shuffle partitions inside the state 
operator. Instead,
     // it will be filled when the stream query progress is reported
@@ -373,9 +431,8 @@ trait StateStoreWriter
     val storeMetrics = store.metrics
     longMetric("numTotalStateRows") += storeMetrics.numKeys
     longMetric("stateMemory") += storeMetrics.memoryUsedBytes
-    storeMetrics.customMetrics.foreach { case (metric, value) =>
-      longMetric(metric.name) += value
-    }
+    setStoreCustomMetrics(storeMetrics.customMetrics)
+    setStoreInstanceMetrics(storeMetrics.instanceMetrics)
 
     if (StatefulOperatorStateInfo.enableStateStoreCheckpointIds(conf)) {
       // Set the state store checkpoint information for the driver to collect
@@ -391,6 +448,22 @@ trait StateStoreWriter
     }
   }
 
+  protected def setStoreCustomMetrics(customMetrics: 
Map[StateStoreCustomMetric, Long]): Unit = {
+    customMetrics.foreach {
+      case (metric, value) =>
+        longMetric(metric.name) += value
+    }
+  }
+
+  protected def setStoreInstanceMetrics(
+      otherStoreInstanceMetrics: Map[StateStoreInstanceMetric, Long]): Unit = {
+    otherStoreInstanceMetrics.foreach {
+      case (metric, value) =>
+        // Update the metric's value based on the defined combine method
+        instanceMetrics(metric).set(metric.combine(instanceMetrics(metric), 
value))
+    }
+  }
+
   private def stateStoreCustomMetrics: Map[String, SQLMetric] = {
     val provider = StateStoreProvider.create(conf.stateStoreProviderClass)
     provider.supportedCustomMetrics.map {
@@ -398,6 +471,20 @@ trait StateStoreWriter
     }.toMap
   }
 
+  private def stateStoreInstanceMetrics: Map[StateStoreInstanceMetric, 
SQLMetric] = {
+    val provider = StateStoreProvider.create(conf.stateStoreProviderClass)
+    val maxPartitions = 
stateInfo.map(_.numPartitions).getOrElse(conf.defaultNumShufflePartitions)
+
+    (0 until maxPartitions).flatMap { partitionId =>
+      provider.supportedInstanceMetrics.flatMap { metric =>
+        stateStoreNames.map { storeName =>
+          val metricWithPartition = metric.withNewId(partitionId, storeName)
+          (metricWithPartition, 
metricWithPartition.createSQLMetric(sparkContext))
+        }
+      }
+    }.toMap
+  }
+
   /**
    * Set of stateful operator custom metrics. These are captured as part of 
the generic
    * key-value map [[StateOperatorProgress.customMetrics]].
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala
index d35bbd49de0d..da4f685aaff8 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala
@@ -183,6 +183,9 @@ class CkptIdCollectingStateStoreProviderWrapper extends 
StateStoreProvider {
 
   override def supportedCustomMetrics: Seq[StateStoreCustomMetric] =
     innerProvider.supportedCustomMetrics
+
+  override def supportedInstanceMetrics: Seq[StateStoreInstanceMetric] =
+    innerProvider.supportedInstanceMetrics
 }
 
 class RocksDBStateStoreCheckpointFormatV2Suite extends StreamTest
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
index 1f4fd7f79571..a0f6d67da5f7 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
@@ -19,21 +19,36 @@ package org.apache.spark.sql.execution.streaming.state
 
 import java.io.File
 
-import scala.jdk.CollectionConverters.SetHasAsScala
+import scala.concurrent.duration.DurationInt
+import scala.jdk.CollectionConverters.{MapHasAsScala, SetHasAsScala}
 
 import org.scalatest.time.{Minute, Span}
 
 import org.apache.spark.sql.execution.streaming.{MemoryStream, 
StreamingQueryWrapper}
-import org.apache.spark.sql.functions.count
+import org.apache.spark.sql.functions.{count, expr}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming._
 import org.apache.spark.sql.streaming.OutputMode.Update
 import org.apache.spark.util.Utils
 
+// SkipMaintenanceOnCertainPartitionsProvider is a test-only provider that 
skips running
+// maintenance for partitions 0 and 1 (these are arbitrary choices). This is 
used to test
+// snapshot upload lag can be observed through StreamingQueryProgress metrics.
+class SkipMaintenanceOnCertainPartitionsProvider extends 
RocksDBStateStoreProvider {
+  override def doMaintenance(): Unit = {
+    if (stateStoreId.partitionId == 0 || stateStoreId.partitionId == 1) {
+      return
+    }
+    super.doMaintenance()
+  }
+}
+
 class RocksDBStateStoreIntegrationSuite extends StreamTest
   with AlsoTestWithRocksDBFeatures {
   import testImplicits._
 
+  private val SNAPSHOT_LAG_METRIC_PREFIX = "SnapshotLastUploaded.partition_"
+
   testWithColumnFamilies("RocksDBStateStore",
     TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled 
=>
     withTempDir { dir =>
@@ -108,7 +123,8 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest
               "rocksdbTotalCompactionLatencyMs", "rocksdbWriterStallLatencyMs",
               "rocksdbTotalBytesReadThroughIterator", 
"rocksdbTotalBytesWrittenByFlush",
               "rocksdbPinnedBlocksMemoryUsage", 
"rocksdbNumInternalColFamiliesKeys",
-              "rocksdbNumExternalColumnFamilies", 
"rocksdbNumInternalColumnFamilies"))
+              "rocksdbNumExternalColumnFamilies", 
"rocksdbNumInternalColumnFamilies",
+              "SnapshotLastUploaded.partition_0_default"))
           }
         } finally {
           query.stop()
@@ -270,4 +286,277 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest
     assert(changelogVersionsPresent(dirForPartition0) == List(3L, 4L))
     assert(snapshotVersionsPresent(dirForPartition0).contains(5L))
   }
+
+  private def snapshotLagMetricName(
+      partitionId: Long,
+      storeName: String = StateStoreId.DEFAULT_STORE_NAME): String = {
+    s"$SNAPSHOT_LAG_METRIC_PREFIX${partitionId}_$storeName"
+  }
+
+  testWithChangelogCheckpointingEnabled(
+    "SPARK-51097: Verify snapshot lag metrics are updated correctly with 
RocksDBStateStoreProvider"
+  ) {
+    withSQLConf(
+      SQLConf.STATE_STORE_PROVIDER_CLASS.key -> 
classOf[RocksDBStateStoreProvider].getName,
+      SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+      SQLConf.STREAMING_NO_DATA_PROGRESS_EVENT_INTERVAL.key -> "10",
+      SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
+      SQLConf.STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT.key -> "3"
+    ) {
+      withTempDir { checkpointDir =>
+        val inputData = MemoryStream[String]
+        val result = inputData.toDS().dropDuplicates()
+
+        testStream(result, outputMode = OutputMode.Update)(
+          StartStream(checkpointLocation = checkpointDir.getCanonicalPath),
+          AddData(inputData, "a"),
+          ProcessAllAvailable(),
+          AddData(inputData, "b"),
+          ProcessAllAvailable(),
+          CheckNewAnswer("a", "b"),
+          Execute { q =>
+            // Make sure only smallest K active metrics are published
+            eventually(timeout(10.seconds)) {
+              val instanceMetrics = q.lastProgress
+                .stateOperators(0)
+                .customMetrics
+                .asScala
+                .view
+                .filterKeys(_.startsWith(SNAPSHOT_LAG_METRIC_PREFIX))
+              // Determined by STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT
+              assert(
+                instanceMetrics.size == q.sparkSession.conf
+                  .get(SQLConf.STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT)
+              )
+              assert(instanceMetrics.forall(_._2 == 1))
+            }
+          },
+          StopStream
+        )
+      }
+    }
+  }
+
+  testWithChangelogCheckpointingEnabled(
+    "SPARK-51097: Verify snapshot lag metrics are updated correctly with " +
+    "SkipMaintenanceOnCertainPartitionsProvider"
+  ) {
+    withSQLConf(
+      SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+        classOf[SkipMaintenanceOnCertainPartitionsProvider].getName,
+      SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+      SQLConf.STREAMING_NO_DATA_PROGRESS_EVENT_INTERVAL.key -> "10",
+      SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
+      SQLConf.STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT.key -> "3"
+    ) {
+      withTempDir { checkpointDir =>
+        val inputData = MemoryStream[String]
+        val result = inputData.toDS().dropDuplicates()
+
+        testStream(result, outputMode = OutputMode.Update)(
+          StartStream(checkpointLocation = checkpointDir.getCanonicalPath),
+          AddData(inputData, "a"),
+          ProcessAllAvailable(),
+          AddData(inputData, "b"),
+          ProcessAllAvailable(),
+          CheckNewAnswer("a", "b"),
+          Execute { q =>
+            // Partitions getting skipped (id 0 and 1) do not have an uploaded 
version, leaving
+            // those instance metrics as -1.
+            eventually(timeout(10.seconds)) {
+              assert(
+                q.lastProgress
+                  .stateOperators(0)
+                  .customMetrics
+                  .get(snapshotLagMetricName(0)) === -1
+              )
+              assert(
+                q.lastProgress
+                  .stateOperators(0)
+                  .customMetrics
+                  .get(snapshotLagMetricName(1)) === -1
+              )
+              // Make sure only smallest K active metrics are published
+              val instanceMetrics = q.lastProgress
+                .stateOperators(0)
+                .customMetrics
+                .asScala
+                .view
+                .filterKeys(_.startsWith(SNAPSHOT_LAG_METRIC_PREFIX))
+              // Determined by STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT
+              assert(
+                instanceMetrics.size == q.sparkSession.conf
+                  .get(SQLConf.STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT)
+              )
+              // Two metrics published are -1, the remainder should all be 1 
as they
+              // uploaded properly.
+              assert(
+                instanceMetrics.count(_._2 == 1) == q.sparkSession.conf
+                  .get(SQLConf.STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT) - 2
+              )
+            }
+          },
+          StopStream
+        )
+      }
+    }
+  }
+
+  testWithChangelogCheckpointingEnabled(
+    "SPARK-51097: Verify snapshot lag metrics are updated correctly for join 
queries with " +
+    "RocksDBStateStoreProvider"
+  ) {
+    withSQLConf(
+      SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+        classOf[RocksDBStateStoreProvider].getName,
+      SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+      SQLConf.STREAMING_NO_DATA_PROGRESS_EVENT_INTERVAL.key -> "10",
+      SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
+      SQLConf.STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT.key -> "10"
+    ) {
+      withTempDir { checkpointDir =>
+        val input1 = MemoryStream[Int]
+        val input2 = MemoryStream[Int]
+
+        val df1 = input1.toDF().select($"value" as "leftKey", ($"value" * 2) 
as "leftValue")
+        val df2 = input2
+          .toDF()
+          .select($"value" as "rightKey", ($"value" * 3) as "rightValue")
+        val joined = df1.join(df2, expr("leftKey = rightKey"))
+
+        testStream(joined)(
+          StartStream(checkpointLocation = checkpointDir.getCanonicalPath),
+          AddData(input1, 1, 5),
+          ProcessAllAvailable(),
+          AddData(input2, 1, 5, 10),
+          ProcessAllAvailable(),
+          CheckNewAnswer((1, 2, 1, 3), (5, 10, 5, 15)),
+          Execute { q =>
+            eventually(timeout(10.seconds)) {
+              // Make sure only smallest K active metrics are published.
+              // There are 5 * 4 = 20 metrics in total because of join, but 
only 10 are published.
+              val instanceMetrics = q.lastProgress
+                .stateOperators(0)
+                .customMetrics
+                .asScala
+                .view
+                .filterKeys(_.startsWith(SNAPSHOT_LAG_METRIC_PREFIX))
+              // Determined by STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT
+              assert(
+                instanceMetrics.size == q.sparkSession.conf
+                  .get(SQLConf.STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT)
+              )
+              // All state store instances should have uploaded a version
+              assert(instanceMetrics.forall(_._2 == 1))
+            }
+          },
+          StopStream
+        )
+      }
+    }
+  }
+
+  testWithChangelogCheckpointingEnabled(
+    "SPARK-51097: Verify snapshot lag metrics are updated correctly for join 
queries with " +
+    "SkipMaintenanceOnCertainPartitionsProvider"
+  ) {
+    withSQLConf(
+      SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+        classOf[SkipMaintenanceOnCertainPartitionsProvider].getName,
+      SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+      SQLConf.STREAMING_NO_DATA_PROGRESS_EVENT_INTERVAL.key -> "10",
+      SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
+      SQLConf.STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT.key -> "10"
+    ) {
+      withTempDir { checkpointDir =>
+        val input1 = MemoryStream[Int]
+        val input2 = MemoryStream[Int]
+
+        val df1 = input1.toDF().select($"value" as "leftKey", ($"value" * 2) 
as "leftValue")
+        val df2 = input2
+          .toDF()
+          .select($"value" as "rightKey", ($"value" * 3) as "rightValue")
+        val joined = df1.join(df2, expr("leftKey = rightKey"))
+
+        testStream(joined)(
+          StartStream(checkpointLocation = checkpointDir.getCanonicalPath),
+          AddData(input1, 1, 5),
+          ProcessAllAvailable(),
+          AddData(input2, 1, 5, 10),
+          ProcessAllAvailable(),
+          CheckNewAnswer((1, 2, 1, 3), (5, 10, 5, 15)),
+          Execute { q =>
+            eventually(timeout(10.seconds)) {
+              // Make sure only smallest K active metrics are published.
+              // There are 5 * 4 = 20 metrics in total because of join, but 
only 10 are published.
+              val allInstanceMetrics = q.lastProgress
+                .stateOperators(0)
+                .customMetrics
+                .asScala
+                .view
+                .filterKeys(_.startsWith(SNAPSHOT_LAG_METRIC_PREFIX))
+              val badInstanceMetrics = allInstanceMetrics.filterKeys(
+                k =>
+                  k.startsWith(snapshotLagMetricName(0, "")) ||
+                  k.startsWith(snapshotLagMetricName(1, ""))
+              )
+              // Determined by STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT
+              assert(
+                allInstanceMetrics.size == q.sparkSession.conf
+                  .get(SQLConf.STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT)
+              )
+              // Two ids are blocked, each with four state stores
+              assert(badInstanceMetrics.count(_._2 == -1) == 2 * 4)
+              // The rest should have uploaded a version
+              assert(
+                allInstanceMetrics.count(_._2 == 1) == q.sparkSession.conf
+                  .get(SQLConf.STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT) - 2 
* 4
+              )
+            }
+          },
+          StopStream
+        )
+      }
+    }
+  }
+
+  testWithChangelogCheckpointingEnabled(
+    "SPARK-51097: Verify RocksDB instance metrics are not collected in 
execution plan"
+  ) {
+    withSQLConf(
+      SQLConf.STATE_STORE_PROVIDER_CLASS.key -> 
classOf[RocksDBStateStoreProvider].getName
+    ) {
+      withTempDir { checkpointDir =>
+        val input1 = MemoryStream[Int]
+        val input2 = MemoryStream[Int]
+
+        val df1 = input1.toDF().select($"value" as "leftKey", ($"value" * 2) 
as "leftValue")
+        val df2 = input2
+          .toDF()
+          .select($"value" as "rightKey", ($"value" * 3) as "rightValue")
+        val joined = df1.join(df2, expr("leftKey = rightKey"))
+
+        testStream(joined)(
+          StartStream(checkpointLocation = checkpointDir.getCanonicalPath),
+          AddData(input1, 1, 5),
+          ProcessAllAvailable(),
+          AddData(input2, 1, 5, 10),
+          ProcessAllAvailable(),
+          CheckNewAnswer((1, 2, 1, 3), (5, 10, 5, 15)),
+          AssertOnQuery { q =>
+            // Go through all elements in the execution plan and verify none 
of the metrics
+            // are generated from RocksDB's snapshot lag instance metrics.
+            q.lastExecution.executedPlan
+              .collect {
+                case node => node.metrics
+              }
+              .forall { nodeMetrics =>
+                nodeMetrics.forall(metric => 
!metric._1.startsWith(SNAPSHOT_LAG_METRIC_PREFIX))
+              }
+          },
+          StopStream
+        )
+      }
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

