This is an automated email from the ASF dual-hosted git repository.

ashrigondekar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new afc45f08c0de [SPARK-54121][SS] Automatic Snapshot Repair for State store
afc45f08c0de is described below

commit afc45f08c0de25d0652cf9c6863a36cf97d0cf96
Author: micheal-o <[email protected]>
AuthorDate: Tue Nov 4 10:59:23 2025 -0800

    [SPARK-54121][SS] Automatic Snapshot Repair for State store
    
    ### What changes were proposed in this pull request?
    
    Today, the engine treats the changelog and snapshot files with the same
    importance: when a state store needs to be loaded, it reads the latest
    snapshot and applies the subsequent changes on top of it. If that snapshot
    is bad or corrupted, the query fails and is completely down and blocked,
    needing manual intervention. In practice this leads to users clearing their
    query checkpoint and having to do a full recomputation.
    
    This shouldn't be the case. The changelog should be treated as the "source
    of truth", and the snapshot as just a disposable materialization of the log.
    
    This PR introduces automatic snapshot repair, which repairs the checkpoint
    by skipping bad snapshots, rebuilding the current state from the last good
    snapshot (or from an empty state if none exists), and applying the
    changelogs on top of it. This eliminates the need for manual intervention
    and keeps the pipeline running.
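    
    A minimal sketch of the fallback order this introduces (condensed from the
    AutoSnapshotLoader added in this PR; the names below are simplified
    placeholders, not the actual API, and the retry threshold controlled by
    numFailuresBeforeActivating is omitted):
    
    ```scala
    // Illustrative only: try snapshots newest-first, fall back to version 0
    // (empty state), then replay change files up to the requested version.
    def loadWithRepair(
        versionToLoad: Long,
        eligibleSnapshots: Seq[Long],          // snapshot versions <= versionToLoad
        maxChangeFileReplay: Int)(
        tryLoadSnapshot: Long => Boolean,      // true if the snapshot loads cleanly
        replayChangelog: (Long, Long) => Unit  // replays change files in (from, to]
    ): Long = {
      // Always keep version 0 (the initial/empty snapshot) as the last resort.
      val candidates = (eligibleSnapshots :+ 0L).distinct.sorted(Ordering[Long].reverse)
      // Older fallbacks must stay within the allowed change-file replay budget.
      val usable = candidates.head +:
        candidates.tail.filter(v => versionToLoad - v <= maxChangeFileReplay)
      val loaded = usable.find(tryLoadSnapshot).getOrElse {
        throw new IllegalStateException("auto snapshot repair failed: no loadable snapshot")
      }
      // Rebuild the requested version by replaying the change files on top.
      replayChangelog(loaded, versionToLoad)
      loaded
    }
    ```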
    
    It also emits a metric for the number of state stores that were
    auto-repaired in a given batch, so that alerts and dashboards can be built
    on top of it.
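    
    For example, a hedged sketch of enabling the (internal) confs added here and
    reading that metric from the last progress; `spark` and `query` are an
    assumed SparkSession and already-running StreamingQuery, and the metric name
    shown is the RocksDB provider's `rocksdbNumSnapshotsAutoRepaired` (the
    HDFS-backed provider reports `numSnapshotsAutoRepaired`):
    
    ```scala
    // Assumed: `spark` is a SparkSession and `query` an already-started StreamingQuery.
    // These confs should be set before the query starts; for RocksDB, changelog
    // checkpointing must also be enabled for repair to activate.
    spark.conf.set("spark.sql.streaming.stateStore.autoSnapshotRepair.enabled", "true")
    spark.conf.set(
      "spark.sql.streaming.stateStore.autoSnapshotRepair.numFailuresBeforeActivating", "1")
    spark.conf.set(
      "spark.sql.streaming.stateStore.autoSnapshotRepair.maxChangeFileReplay", "50")
    
    // Sum the auto-repair metric across state operators of the last micro-batch.
    val repaired = Option(query.lastProgress).toSeq
      .flatMap(_.stateOperators)
      .map(op => Option(op.customMetrics.get("rocksdbNumSnapshotsAutoRepaired"))
        .map(_.longValue).getOrElse(0L))
      .sum
    if (repaired > 0) {
      println(s"Auto-repaired $repaired state store snapshot(s) in the last batch")
    }
    ```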
    
    ### Why are the changes needed?
    
    Automatic failure recovery: queries can keep running past corrupted
    snapshots without manual intervention or checkpoint deletion.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added new tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #52821 from micheal-o/auto_snapshot_repair.
    
    Authored-by: micheal-o <[email protected]>
    Signed-off-by: Anish Shrigondekar <[email protected]>
---
 .../src/main/resources/error/error-conditions.json |   5 +
 .../org/apache/spark/sql/internal/SQLConf.scala    |  46 +++++
 .../streaming/ClientStreamingQuerySuite.scala      |   1 +
 .../streaming/state/AutoSnapshotLoader.scala       | 193 +++++++++++++++++++++
 .../state/HDFSBackedStateStoreProvider.scala       | 108 ++++++++----
 .../sql/execution/streaming/state/RocksDB.scala    |  83 ++++++---
 .../streaming/state/RocksDBFileManager.scala       |   6 +
 .../state/RocksDBStateStoreProvider.scala          |   8 +-
 .../execution/streaming/state/StateStoreConf.scala |  11 ++
 .../streaming/state/StateStoreErrors.scala         |  25 +++
 .../streaming/state/AutoSnapshotLoaderSuite.scala  | 177 +++++++++++++++++++
 .../state/RocksDBStateStoreIntegrationSuite.scala  |   3 +
 .../execution/streaming/state/RocksDBSuite.scala   |  77 ++++++++
 .../streaming/state/StateStoreSuite.scala          |  88 +++++++++-
 14 files changed, 773 insertions(+), 58 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json 
b/common/utils/src/main/resources/error/error-conditions.json
index a34ceb9f1145..f7eb1e63d7bd 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -320,6 +320,11 @@
       "An error occurred during loading state."
     ],
     "subClass" : {
+      "AUTO_SNAPSHOT_REPAIR_FAILED" : {
+        "message" : [
+          "Failed to load snapshot version <latestSnapshot> for state store 
<stateStoreId>. An attempt to auto repair using snapshot versions 
(<selectedSnapshots>) out of available snapshots (<eligibleSnapshots>) also 
failed."
+        ]
+      },
       "CANNOT_FIND_BASE_SNAPSHOT_CHECKPOINT" : {
         "message" : [
           "Cannot find a base snapshot checkpoint with lineage: <lineage>."
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index b8907629ad37..d2d7edc65121 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2546,6 +2546,43 @@ object SQLConf {
       .intConf
       .createWithDefault(10)
 
+  val STATE_STORE_AUTO_SNAPSHOT_REPAIR_ENABLED =
+    buildConf("spark.sql.streaming.stateStore.autoSnapshotRepair.enabled")
+      .internal()
+      .doc("When true, enables automatic repair of state store snapshot, when 
a bad snapshot is " +
+        "detected while loading the state store, to prevent the query from 
failing. " +
+        "Typically, queries will fail when they are unable to load a snapshot, 
" +
+        "but this helps recover by skipping the bad snapshot and uses the 
change files." +
+        "NOTE: For RocksDB state store, changelog checkpointing must be 
enabled")
+      .version("4.1.0")
+      .booleanConf
+      // Disable in tests, so that tests will fail if they encounter bad 
snapshot
+      .createWithDefault(!Utils.isTesting)
+
+  val STATE_STORE_AUTO_SNAPSHOT_REPAIR_NUM_FAILURES_BEFORE_ACTIVATING =
+    
buildConf("spark.sql.streaming.stateStore.autoSnapshotRepair.numFailuresBeforeActivating")
+      .internal()
+      .doc(
+        "When autoSnapshotRepair is enabled, it will wait for the specified 
number of snapshot " +
+          "load failures, before it attempts to repair."
+      )
+      .version("4.1.0")
+      .intConf
+      .checkValue(k => k > 0, "Must allow at least 1 failure before activating 
autoSnapshotRepair")
+      .createWithDefault(1)
+
+  val STATE_STORE_AUTO_SNAPSHOT_REPAIR_MAX_CHANGE_FILE_REPLAY =
+    
buildConf("spark.sql.streaming.stateStore.autoSnapshotRepair.maxChangeFileReplay")
+      .internal()
+      .doc(
+        "When autoSnapshotRepair is enabled, this specifies the maximum number 
of change " +
+          "files allowed to be replayed to rebuild state due to bad snapshots."
+      )
+      .version("4.1.0")
+      .intConf
+      .checkValue(k => k > 0, "Must allow at least 1 change file replay")
+      .createWithDefault(50)
+
   val STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT =
     
buildConf("spark.sql.streaming.stateStore.numStateStoreInstanceMetricsToReport")
       .internal()
@@ -6729,6 +6766,15 @@ class SQLConf extends Serializable with Logging with 
SqlApiConf {
 
   def stateStoreMinDeltasForSnapshot: Int = 
getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT)
 
+  def stateStoreAutoSnapshotRepairEnabled: Boolean =
+    getConf(STATE_STORE_AUTO_SNAPSHOT_REPAIR_ENABLED)
+
+  def stateStoreAutoSnapshotRepairNumFailuresBeforeActivating: Int =
+    getConf(STATE_STORE_AUTO_SNAPSHOT_REPAIR_NUM_FAILURES_BEFORE_ACTIVATING)
+
+  def stateStoreAutoSnapshotRepairMaxChangeFileReplay: Int =
+    getConf(STATE_STORE_AUTO_SNAPSHOT_REPAIR_MAX_CHANGE_FILE_REPLAY)
+
   def stateStoreFormatValidationEnabled: Boolean = 
getConf(STATE_STORE_FORMAT_VALIDATION_ENABLED)
 
   def stateStoreSkipNullsForStreamStreamJoins: Boolean =
diff --git 
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/ClientStreamingQuerySuite.scala
 
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/ClientStreamingQuerySuite.scala
index 580b8e1114f9..ef3bb4711c85 100644
--- 
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/ClientStreamingQuerySuite.scala
+++ 
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/ClientStreamingQuerySuite.scala
@@ -103,6 +103,7 @@ class ClientStreamingQuerySuite extends QueryTest with 
RemoteSparkSession with L
           lastProgress.stateOperators.head.customMetrics.keySet().asScala == 
Set(
             "loadedMapCacheHitCount",
             "loadedMapCacheMissCount",
+            "numSnapshotsAutoRepaired",
             "stateOnCurrentVersionSizeBytes",
             "SnapshotLastUploaded.partition_0_default"))
         assert(lastProgress.sources.nonEmpty)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/AutoSnapshotLoader.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/AutoSnapshotLoader.scala
new file mode 100644
index 000000000000..d94f10d49fbd
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/AutoSnapshotLoader.scala
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import scala.collection.immutable.ArraySeq
+import scala.util.control.NonFatal
+
+import org.apache.hadoop.fs.{Path, PathFilter}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.LogKeys.{NUM_RETRIES, NUM_RETRY, VERSION_NUM}
+import 
org.apache.spark.sql.execution.streaming.checkpointing.CheckpointFileManager
+
+/**
+ * [[AutoSnapshotLoader]] is used to handle loading state store snapshot 
version from the
+ * checkpoint directory. It supports Auto snapshot repair, which will 
automatically handle
+ * corrupt snapshots and skip them, by using another snapshot version before 
the corrupt one.
+ * If no snapshot exists before the corrupt one, then it will use the 0 
version snapshot
+ * (represents initial/empty snapshot).
+ *
+ * @param autoSnapshotRepairEnabled If true, it will handle corrupt snapshot
+ * @param numFailuresBeforeActivating If auto snapshot repair is enabled,
+ *                                    number of failures before activating it
+ * @param maxChangeFileReplay If auto snapshot repair is enabled, maximum 
difference between
+ *                            the requested snapshot version and the selected 
snapshot version
+ * @param loggingId To append to log messages
+ * */
+abstract class AutoSnapshotLoader(
+    autoSnapshotRepairEnabled: Boolean,
+    numFailuresBeforeActivating: Int,
+    maxChangeFileReplay: Int,
+    loggingId: String = "") extends Logging {
+
+  override protected def logName: String = s"${super.logName} $loggingId"
+
+  /** Called before loading a snapshot from the checkpoint directory */
+  protected def beforeLoad(): Unit
+
+  /**
+   * Attempt to load the specified snapshot version from the checkpoint 
directory.
+   * Should throw an exception if the snapshot is corrupt.
+   * @note Must support loading version 0
+   * */
+  protected def loadSnapshotFromCheckpoint(snapshotVersion: Long): Unit
+
+  /** Called when load fails, to do any necessary cleanup/variable reset */
+  protected def onLoadSnapshotFromCheckpointFailure(): Unit
+
+  /** Get a list of eligible snapshot versions in the checkpoint directory 
that can be loaded */
+  protected def getEligibleSnapshots(versionToLoad: Long): Seq[Long]
+
+  /**
+   * Load the latest snapshot for the specified version from the checkpoint 
directory.
+   * If Auto snapshot repair is enabled, the snapshot version loaded may be 
lower than
+   * the latest snapshot version, if the latest is corrupt.
+   *
+   * @param versionToLoad The version to load latest snapshot for
+   * @return The actual loaded snapshot version and if it was due to auto 
repair
+   * */
+  def loadSnapshot(versionToLoad: Long): (Long, Boolean) = {
+    val eligibleSnapshots =
+      (getEligibleSnapshots(versionToLoad) :+ 0L) // always include the 
initial snapshot
+      .distinct // Ensure no duplicate version numbers
+      .sorted(Ordering[Long].reverse)
+
+    // Start with the latest snapshot
+    val firstEligibleSnapshot = eligibleSnapshots.head
+
+    // no retry if auto snapshot repair is not enabled
+    val maxNumFailures = if (autoSnapshotRepairEnabled) 
numFailuresBeforeActivating else 1
+    var numFailuresForFirstSnapshot = 0
+    var lastException: Throwable = null
+    var loadedSnapshot: Option[Long] = None
+    while (loadedSnapshot.isEmpty && numFailuresForFirstSnapshot < 
maxNumFailures) {
+      beforeLoad() // if this fails, then we should fail
+      try {
+        // try to load the first eligible snapshot
+        loadSnapshotFromCheckpoint(firstEligibleSnapshot)
+        loadedSnapshot = Some(firstEligibleSnapshot)
+      } catch {
+        // Swallow only if auto snapshot repair is enabled
+        // If auto snapshot repair is not enabled, we should fail immediately
+        case NonFatal(e) if autoSnapshotRepairEnabled =>
+          onLoadSnapshotFromCheckpointFailure()
+          numFailuresForFirstSnapshot += 1
+          logError(log"Failed to load snapshot version " +
+            log"${MDC(VERSION_NUM, firstEligibleSnapshot)}, " +
+            log"attempt ${MDC(NUM_RETRY, numFailuresForFirstSnapshot)} out of 
" +
+            log"${MDC(NUM_RETRIES, maxNumFailures)} attempts", e)
+          lastException = e
+        case e: Throwable =>
+          onLoadSnapshotFromCheckpointFailure()
+          throw e
+      }
+    }
+
+    var autoRepairCompleted = false
+    if (loadedSnapshot.isEmpty) {
+      // we would only get here if auto snapshot repair is enabled
+      assert(autoSnapshotRepairEnabled)
+
+      val remainingEligibleSnapshots = if (eligibleSnapshots.length > 1) {
+        // skip the first snapshot, since we already tried it
+        eligibleSnapshots.tail
+      } else {
+        // no more snapshots to try
+        Seq.empty
+      }
+
+      // select remaining snapshots that are within the maxChangeFileReplay 
limit
+      val selectedRemainingSnapshots = remainingEligibleSnapshots.filter(
+        s => versionToLoad - s <= maxChangeFileReplay)
+
+      logInfo(log"Attempting to auto repair snapshot by skipping " +
+        log"snapshot version ${MDC(VERSION_NUM, firstEligibleSnapshot)} " +
+        log"and trying to load with one of the selected snapshots " +
+        log"${MDC(VERSION_NUM, selectedRemainingSnapshots)}, out of eligible 
snapshots " +
+        log"${MDC(VERSION_NUM, remainingEligibleSnapshots)}. " +
+        log"maxChangeFileReplay: ${MDC(VERSION_NUM, maxChangeFileReplay)}")
+
+      // Now try to load using any of the selected snapshots,
+      // remember they are sorted in descending order
+      for (snapshotVersion <- selectedRemainingSnapshots if 
loadedSnapshot.isEmpty) {
+        beforeLoad() // if this fails, then we should fail
+        try {
+          loadSnapshotFromCheckpoint(snapshotVersion)
+          loadedSnapshot = Some(snapshotVersion)
+          logInfo(log"Successfully loaded snapshot version " +
+            log"${MDC(VERSION_NUM, snapshotVersion)}. Repair complete.")
+        } catch {
+          case NonFatal(e) =>
+            logError(log"Failed to load snapshot version " +
+              log"${MDC(VERSION_NUM, snapshotVersion)}, will retry repair with 
" +
+              log"the next eligible snapshot version", e)
+            onLoadSnapshotFromCheckpointFailure()
+            lastException = e
+        }
+      }
+
+      if (loadedSnapshot.isEmpty) {
+        // we tried all eligible snapshots and failed to load any of them
+        logError(log"Auto snapshot repair failed to load any snapshot:" +
+          log" latestSnapshotVersion: ${MDC(VERSION_NUM, 
firstEligibleSnapshot)}, " +
+          log"attemptedSnapshots: ${MDC(VERSION_NUM, 
selectedRemainingSnapshots)}, " +
+          log"eligibleSnapshots:  ${MDC(VERSION_NUM, 
remainingEligibleSnapshots)}, " +
+          log"maxChangeFileReplay: ${MDC(VERSION_NUM, maxChangeFileReplay)}", 
lastException)
+        throw StateStoreErrors.autoSnapshotRepairFailed(
+          loggingId, firstEligibleSnapshot, selectedRemainingSnapshots, 
remainingEligibleSnapshots,
+          lastException)
+      } else {
+        autoRepairCompleted = true
+      }
+    }
+
+    // we would only get here if we successfully loaded a snapshot
+    (loadedSnapshot.get, autoRepairCompleted)
+  }
+}
+
+object SnapshotLoaderHelper {
+  /** Get all the snapshot versions that can be used to load this version */
+  def getEligibleSnapshotsForVersion(
+      version: Long,
+      fm: CheckpointFileManager,
+      dfsPath: Path,
+      pathFilter: PathFilter,
+      fileSuffix: String): Seq[Long] = {
+    if (fm.exists(dfsPath)) {
+      ArraySeq.unsafeWrapArray(
+        fm.list(dfsPath, pathFilter)
+          .map(_.getPath.getName.stripSuffix(fileSuffix))
+          .map(_.toLong)
+      ).filter(_ <= version)
+    } else {
+      Seq(0L)
+    }
+  }
+}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index aa4fa9bfaf62..f1c9c94e7bf8 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming.state
 import java.io._
 import java.util
 import java.util.{Locale, UUID}
-import java.util.concurrent.atomic.{AtomicLong, LongAdder}
+import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong, LongAdder}
 
 import scala.collection.mutable
 import scala.jdk.CollectionConverters._
@@ -295,7 +295,9 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
   def getMetricsForProvider(): Map[String, Long] = synchronized {
     Map("memoryUsedBytes" -> SizeEstimator.estimate(loadedMaps),
       metricLoadedMapCacheHit.name -> loadedMapCacheHitCount.sum(),
-      metricLoadedMapCacheMiss.name -> loadedMapCacheMissCount.sum())
+      metricLoadedMapCacheMiss.name -> loadedMapCacheMissCount.sum(),
+      metricNumSnapshotsAutoRepaired.name -> (if 
(performedSnapshotAutoRepair.get()) 1 else 0)
+    )
   }
 
   /** Get the state store for making updates to create a new `version` of the 
store. */
@@ -324,6 +326,8 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
       if (version < 0) {
         throw QueryExecutionErrors.unexpectedStateStoreVersion(version)
       }
+
+      performedSnapshotAutoRepair.set(false)
       val newMap = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
       if (version > 0) {
         newMap.putAll(loadMap(version))
@@ -426,6 +430,7 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
 
   override def supportedCustomMetrics: Seq[StateStoreCustomMetric] = {
     metricStateOnCurrentVersionSizeBytes :: metricLoadedMapCacheHit :: 
metricLoadedMapCacheMiss ::
+    metricNumSnapshotsAutoRepaired ::
       Nil
   }
 
@@ -471,6 +476,9 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
       mgr
     }
   }
+  private val onlySnapshotFiles = new PathFilter {
+    override def accept(path: Path): Boolean = 
path.toString.endsWith(".snapshot")
+  }
   private lazy val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new 
SparkConf)
 
   private val loadedMapCacheHitCount: LongAdder = new LongAdder
@@ -479,6 +487,8 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
   // This is updated when the maintenance task writes the snapshot file and 
read by the task
   // thread. -1 represents no version has ever been uploaded.
   private val lastUploadedSnapshotVersion: AtomicLong = new AtomicLong(-1L)
+  // Was snapshot auto repair performed when loading the current version
+  private val performedSnapshotAutoRepair: AtomicBoolean = new 
AtomicBoolean(false)
 
   private lazy val metricStateOnCurrentVersionSizeBytes: 
StateStoreCustomSizeMetric =
     StateStoreCustomSizeMetric("stateOnCurrentVersionSizeBytes",
@@ -492,6 +502,10 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
     StateStoreCustomSumMetric("loadedMapCacheMissCount",
       "count of cache miss on states cache in provider")
 
+  private lazy val metricNumSnapshotsAutoRepaired: StateStoreCustomMetric =
+    StateStoreCustomSumMetric("numSnapshotsAutoRepaired",
+    "number of snapshots that were automatically repaired during store load")
+
   private lazy val instanceMetricSnapshotLastUpload: StateStoreInstanceMetric =
     StateStoreSnapshotLastUploadInstanceMetric()
 
@@ -593,52 +607,78 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
     loadedMapCacheMissCount.increment()
 
     val (result, elapsedMs) = Utils.timeTakenMs {
-      val snapshotCurrentVersionMap = readSnapshotFile(version)
-      if (snapshotCurrentVersionMap.isDefined) {
-        synchronized { putStateIntoStateCacheMap(version, 
snapshotCurrentVersionMap.get) }
+      val (loadedVersion, loadedMap) = loadSnapshot(version)
+      val finalMap = if (loadedVersion == version) {
+        loadedMap
+      } else {
+        // Load all the deltas from the version after the loadedVersion up to 
the target version.
+        // The loadedVersion is the one with a full snapshot, so it doesn't 
need deltas.
+        val resultMap = HDFSBackedStateStoreMap.create(keySchema, 
numColsPrefixKey)
+        resultMap.putAll(loadedMap)
+        for (deltaVersion <- loadedVersion + 1 to version) {
+          updateFromDeltaFile(deltaVersion, resultMap)
+        }
+        resultMap
+      }
 
-        // Report the loaded snapshot's version to the coordinator
-        reportSnapshotUploadToCoordinator(version)
+      // Synchronize and update the state cache map
+      synchronized { putStateIntoStateCacheMap(version, finalMap) }
 
-        return snapshotCurrentVersionMap.get
-      }
+      // Report the snapshot found to the coordinator
+      reportSnapshotUploadToCoordinator(loadedVersion)
 
-      // Find the most recent map before this version that we can.
-      // [SPARK-22305] This must be done iteratively to avoid stack overflow.
-      var lastAvailableVersion = version
-      var lastAvailableMap: Option[HDFSBackedStateStoreMap] = None
-      while (lastAvailableMap.isEmpty) {
-        lastAvailableVersion -= 1
+      finalMap
+    }
 
-        if (lastAvailableVersion <= 0) {
+    logDebug(s"Loading state for $version takes $elapsedMs ms.")
+
+    result
+  }
+
+  /** Loads the latest snapshot for the version we want to load and
+   * returns the snapshot version and map representing the snapshot */
+  private def loadSnapshot(versionToLoad: Long): (Long, 
HDFSBackedStateStoreMap) = {
+    var loadedMap: Option[HDFSBackedStateStoreMap] = None
+    val storeIdStr = s"StateStoreId(opId=${stateStoreId_.operatorId}," +
+      s"partId=${stateStoreId_.partitionId},name=${stateStoreId_.storeName})"
+
+    val snapshotLoader = new AutoSnapshotLoader(
+      storeConf.autoSnapshotRepairEnabled,
+      storeConf.autoSnapshotRepairNumFailuresBeforeActivating,
+      storeConf.autoSnapshotRepairMaxChangeFileReplay,
+      storeIdStr) {
+      override protected def beforeLoad(): Unit = {}
+
+      override protected def loadSnapshotFromCheckpoint(snapshotVersion: 
Long): Unit = {
+        loadedMap = if (snapshotVersion <= 0) {
           // Use an empty map for versions 0 or less.
-          lastAvailableMap = Some(HDFSBackedStateStoreMap.create(keySchema, 
numColsPrefixKey))
+          Some(HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey))
         } else {
-          lastAvailableMap =
-            synchronized { Option(loadedMaps.get(lastAvailableVersion)) }
-              .orElse(readSnapshotFile(lastAvailableVersion))
+          // first try to get the map from the cache
+          synchronized { Option(loadedMaps.get(snapshotVersion)) }
+            .orElse(readSnapshotFile(snapshotVersion))
         }
       }
 
-      // Load all the deltas from the version after the last available one up 
to the target version.
-      // The last available version is the one with a full snapshot, so it 
doesn't need deltas.
-      val resultMap = HDFSBackedStateStoreMap.create(keySchema, 
numColsPrefixKey)
-      resultMap.putAll(lastAvailableMap.get)
-      for (deltaVersion <- lastAvailableVersion + 1 to version) {
-        updateFromDeltaFile(deltaVersion, resultMap)
-      }
+      override protected def onLoadSnapshotFromCheckpointFailure(): Unit = {}
 
-      synchronized { putStateIntoStateCacheMap(version, resultMap) }
+      override protected def getEligibleSnapshots(versionToLoad: Long): 
Seq[Long] = {
+        val snapshotVersions = 
SnapshotLoaderHelper.getEligibleSnapshotsForVersion(
+          versionToLoad, fm, baseDir, onlySnapshotFiles, fileSuffix = 
".snapshot")
 
-      // Report the last available snapshot's version to the coordinator
-      reportSnapshotUploadToCoordinator(lastAvailableVersion)
+        // Get locally cached versions, so we can use the locally cached 
version if available.
+        val cachedVersions = synchronized {
+          loadedMaps.keySet.asScala.toSeq
+        }.filter(_ <= versionToLoad)
 
-      resultMap
+        // Combine the two sets of versions, so we can check both during load
+        (snapshotVersions ++ cachedVersions).distinct
+      }
     }
 
-    logDebug(s"Loading state for $version takes $elapsedMs ms.")
-
-    result
+    val (loadedVersion, autoRepairCompleted) = 
snapshotLoader.loadSnapshot(versionToLoad)
+    performedSnapshotAutoRepair.set(autoRepairCompleted)
+    (loadedVersion, loadedMap.get)
   }
 
   private def writeUpdateToDeltaFile(
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
index fb3ef606b8f3..f8570c583387 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
@@ -219,6 +219,9 @@ class RocksDB(
   @volatile private var numInternalKeysOnLoadedVersion = 0L
   @volatile private var numInternalKeysOnWritingVersion = 0L
 
+  // Was snapshot auto repair performed when loading the current version
+  @volatile private var performedSnapshotAutoRepair = false
+
   @volatile private var fileManagerMetrics = 
RocksDBFileManagerMetrics.EMPTY_METRICS
 
   // SPARK-46249 - Keep track of recorded metrics per version which can be 
used for querying later
@@ -541,24 +544,9 @@ class RocksDB(
     try {
       if (loadedVersion != version) {
         closeDB(ignoreException = false)
-        val latestSnapshotVersion = 
fileManager.getLatestSnapshotVersion(version)
-        val metadata = fileManager.loadCheckpointFromDfs(
-          latestSnapshotVersion,
-          workingDir,
-          rocksDBFileMapping)
-
-        loadedVersion = latestSnapshotVersion
-
-        // reset the last snapshot version to the latest available snapshot 
version
-        lastSnapshotVersion = latestSnapshotVersion
-
-        // Initialize maxVersion upon successful load from DFS
-        fileManager.setMaxSeenVersion(version)
-
-        // Report this snapshot version to the coordinator
-        reportSnapshotUploadToCoordinator(latestSnapshotVersion)
 
-        openLocalRocksDB(metadata)
+        // load the latest snapshot
+        loadSnapshotWithoutCheckpointId(version)
 
         if (loadedVersion != version) {
           val versionsAndUniqueIds: Array[(Long, Option[String])] =
@@ -589,6 +577,54 @@ class RocksDB(
     this
   }
 
+  private def loadSnapshotWithoutCheckpointId(versionToLoad: Long): Long = {
+    // Don't allow auto snapshot repair if changelog checkpointing is not 
enabled
+    // since it relies on changelog to rebuild state.
+    val allowAutoSnapshotRepair = if (enableChangelogCheckpointing) {
+      conf.stateStoreConf.autoSnapshotRepairEnabled
+    } else {
+      false
+    }
+    val snapshotLoader = new AutoSnapshotLoader(
+      allowAutoSnapshotRepair,
+      conf.stateStoreConf.autoSnapshotRepairNumFailuresBeforeActivating,
+      conf.stateStoreConf.autoSnapshotRepairMaxChangeFileReplay,
+      loggingId) {
+      override protected def beforeLoad(): Unit = closeDB(ignoreException = 
false)
+
+      override protected def loadSnapshotFromCheckpoint(snapshotVersion: 
Long): Unit = {
+        val remoteMetaData = fileManager.loadCheckpointFromDfs(snapshotVersion,
+          workingDir, rocksDBFileMapping)
+
+        loadedVersion = snapshotVersion
+        // Initialize maxVersion upon successful load from DFS
+        fileManager.setMaxSeenVersion(snapshotVersion)
+
+        openLocalRocksDB(remoteMetaData)
+
+        // By setting this to the snapshot version we successfully loaded,
+        // if auto snapshot repair is enabled, and we end up skipping the 
latest snapshot
+        // and used an older one, we will create a new snapshot at commit time
+        // if the loaded one is old enough.
+        lastSnapshotVersion = snapshotVersion
+        // Report this snapshot version to the coordinator
+        reportSnapshotUploadToCoordinator(snapshotVersion)
+      }
+
+      override protected def onLoadSnapshotFromCheckpointFailure(): Unit = {
+        loadedVersion = -1  // invalidate loaded data
+      }
+
+      override protected def getEligibleSnapshots(version: Long): Seq[Long] = {
+        fileManager.getEligibleSnapshotsForVersion(version)
+      }
+    }
+
+    val (version, autoRepairCompleted) = 
snapshotLoader.loadSnapshot(versionToLoad)
+    performedSnapshotAutoRepair = autoRepairCompleted
+    version
+  }
+
   /**
    * Function to check if col family is internal or not based on information 
recorded in
    * checkpoint metadata.
@@ -657,6 +693,7 @@ class RocksDB(
 
     assert(version >= 0)
     recordedMetrics = None
+    performedSnapshotAutoRepair = false
     // Reset the load metrics before loading
     loadMetrics.clear()
 
@@ -1622,7 +1659,8 @@ class RocksDB(
       filesReused = fileManagerMetrics.filesReused,
       lastUploadedSnapshotVersion = lastUploadedSnapshotVersion.get(),
       zipFileBytesUncompressed = fileManagerMetrics.zipFileBytesUncompressed,
-      nativeOpsMetrics = nativeOpsMetrics)
+      nativeOpsMetrics = nativeOpsMetrics,
+      numSnapshotsAutoRepaired = if (performedSnapshotAutoRepair) 1 else 0)
   }
 
   /**
@@ -2067,7 +2105,8 @@ case class RocksDBConf(
     compression: String,
     reportSnapshotUploadLag: Boolean,
     fileChecksumEnabled: Boolean,
-    maxVersionsToDeletePerMaintenance: Int)
+    maxVersionsToDeletePerMaintenance: Int,
+    stateStoreConf: StateStoreConf)
 
 object RocksDBConf {
   /** Common prefix of all confs in SQLConf that affects RocksDB */
@@ -2267,7 +2306,8 @@ object RocksDBConf {
       getStringConf(COMPRESSION_CONF),
       storeConf.reportSnapshotUploadLag,
       storeConf.checkpointFileChecksumEnabled,
-      storeConf.maxVersionsToDeletePerMaintenance)
+      storeConf.maxVersionsToDeletePerMaintenance,
+      storeConf)
   }
 
   def apply(): RocksDBConf = apply(new StateStoreConf())
@@ -2289,7 +2329,8 @@ case class RocksDBMetrics(
     filesReused: Long,
     zipFileBytesUncompressed: Option[Long],
     nativeOpsMetrics: Map[String, Long],
-    lastUploadedSnapshotVersion: Long) {
+    lastUploadedSnapshotVersion: Long,
+    numSnapshotsAutoRepaired: Long) {
   def json: String = Serialization.write(this)(RocksDBMetrics.format)
 }
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala
index 2e86ff70d58f..92fa5d0350fa 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala
@@ -427,6 +427,12 @@ class RocksDBFileManager(
     }
   }
 
+  /** Get all the snapshot versions that can be used to load this version */
+  def getEligibleSnapshotsForVersion(version: Long): Seq[Long] = {
+    SnapshotLoaderHelper.getEligibleSnapshotsForVersion(
+      version, fm, new Path(dfsRootDir), onlyZipFiles, fileSuffix = ".zip")
+  }
+
   /**
    * Based on the ground truth lineage loaded from changelog file (lineage), 
this function
    * does file listing to find all snapshot (version, uniqueId) pairs, and 
finds
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
index 2cc4c8a870ae..e01e1e0f86ca 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -552,7 +552,8 @@ private[sql] class RocksDBStateStoreProvider
           CUSTOM_METRIC_PINNED_BLOCKS_MEM_USAGE -> 
rocksDBMetrics.pinnedBlocksMemUsage,
           CUSTOM_METRIC_NUM_INTERNAL_COL_FAMILIES_KEYS -> 
rocksDBMetrics.numInternalKeys,
           CUSTOM_METRIC_NUM_EXTERNAL_COL_FAMILIES -> internalColFamilyCnt(),
-          CUSTOM_METRIC_NUM_INTERNAL_COL_FAMILIES -> externalColFamilyCnt()
+          CUSTOM_METRIC_NUM_INTERNAL_COL_FAMILIES -> externalColFamilyCnt(),
+          CUSTOM_METRIC_NUM_SNAPSHOTS_AUTO_REPAIRED -> 
rocksDBMetrics.numSnapshotsAutoRepaired
         ) ++ rocksDBMetrics.zipFileBytesUncompressed.map(bytes =>
           Map(CUSTOM_METRIC_ZIP_FILE_BYTES_UNCOMPRESSED -> 
bytes)).getOrElse(Map())
 
@@ -1261,6 +1262,9 @@ object RocksDBStateStoreProvider {
   // Total SST file size
   val CUSTOM_METRIC_SST_FILE_SIZE = StateStoreCustomSizeMetric(
     "rocksdbSstFileSize", "RocksDB: size of all SST files")
+  val CUSTOM_METRIC_NUM_SNAPSHOTS_AUTO_REPAIRED = StateStoreCustomSumMetric(
+    "rocksdbNumSnapshotsAutoRepaired",
+    "RocksDB: number of snapshots that were automatically repaired during 
store load")
 
   val ALL_CUSTOM_METRICS = Seq(
     CUSTOM_METRIC_SST_FILE_SIZE, CUSTOM_METRIC_GET_TIME, 
CUSTOM_METRIC_PUT_TIME,
@@ -1276,7 +1280,7 @@ object RocksDBStateStoreProvider {
     CUSTOM_METRIC_PINNED_BLOCKS_MEM_USAGE, 
CUSTOM_METRIC_NUM_INTERNAL_COL_FAMILIES_KEYS,
     CUSTOM_METRIC_NUM_EXTERNAL_COL_FAMILIES, 
CUSTOM_METRIC_NUM_INTERNAL_COL_FAMILIES,
     CUSTOM_METRIC_LOAD_FROM_SNAPSHOT_TIME, CUSTOM_METRIC_LOAD_TIME, 
CUSTOM_METRIC_REPLAY_CHANGE_LOG,
-    CUSTOM_METRIC_NUM_REPLAY_CHANGE_LOG_FILES)
+    CUSTOM_METRIC_NUM_REPLAY_CHANGE_LOG_FILES, 
CUSTOM_METRIC_NUM_SNAPSHOTS_AUTO_REPAIRED)
 
   val CUSTOM_INSTANCE_METRIC_SNAPSHOT_LAST_UPLOADED = 
StateStoreSnapshotLastUploadInstanceMetric()
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
index 74904a37f450..a765f52a2272 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
@@ -48,6 +48,17 @@ class StateStoreConf(
    */
   val minDeltasForSnapshot: Int = sqlConf.stateStoreMinDeltasForSnapshot
 
+  /** Whether we should enable automatic snapshot repair */
+  val autoSnapshotRepairEnabled: Boolean = 
sqlConf.stateStoreAutoSnapshotRepairEnabled
+
+  /** Number of failures before activating auto snapshot repair when enabled */
+  val autoSnapshotRepairNumFailuresBeforeActivating: Int =
+    sqlConf.stateStoreAutoSnapshotRepairNumFailuresBeforeActivating
+
+  /** Maximum number of change files allowed to be replayed when auto snapshot 
repair is enabled */
+  val autoSnapshotRepairMaxChangeFileReplay: Int =
+    sqlConf.stateStoreAutoSnapshotRepairMaxChangeFileReplay
+
   /** Minimum versions a State Store implementation should retain to allow 
rollbacks */
   val minVersionsToRetain: Int = sqlConf.minBatchesToRetain
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
index 970499a054b5..23bb54d86348 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
@@ -254,6 +254,16 @@ object StateStoreErrors {
     new StateStoreUnexpectedEmptyFileInRocksDBZip(fileName, zipFileName)
   }
 
+  def autoSnapshotRepairFailed(
+      stateStoreId: String,
+      latestSnapshot: Long,
+      selectedSnapshots: Seq[Long],
+      eligibleSnapshots: Seq[Long],
+      cause: Throwable): StateStoreAutoSnapshotRepairFailed = {
+    new StateStoreAutoSnapshotRepairFailed(
+      stateStoreId, latestSnapshot, selectedSnapshots, eligibleSnapshots, 
cause)
+  }
+
   def cannotLoadStore(e: Throwable): Throwable = {
     e match {
       case e: SparkException
@@ -583,3 +593,18 @@ class StateStoreUnexpectedEmptyFileInRocksDBZip(fileName: 
String, zipFileName: S
       "fileName" -> fileName,
       "zipFileName" -> zipFileName),
     cause = null)
+
+class StateStoreAutoSnapshotRepairFailed(
+    stateStoreId: String,
+    latestSnapshot: Long,
+    selectedSnapshots: Seq[Long],
+    eligibleSnapshots: Seq[Long],
+    cause: Throwable)
+  extends SparkRuntimeException(
+    errorClass = "CANNOT_LOAD_STATE_STORE.AUTO_SNAPSHOT_REPAIR_FAILED",
+    messageParameters = Map(
+      "latestSnapshot" -> latestSnapshot.toString,
+      "stateStoreId" -> stateStoreId,
+      "selectedSnapshots" -> selectedSnapshots.mkString(","),
+      "eligibleSnapshots" -> eligibleSnapshots.mkString(",")),
+    cause)
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/AutoSnapshotLoaderSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/AutoSnapshotLoaderSuite.scala
new file mode 100644
index 000000000000..186248b43bc8
--- /dev/null
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/AutoSnapshotLoaderSuite.scala
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.spark.SparkFunSuite
+
+/**
+ * Suite to test [[AutoSnapshotLoader]]. Tests different behaviors including
+ * when repair is enabled/disabled, when numFailuresBeforeActivating is set,
+ * when maxChangeFileReplay is set.
+ */
+class AutoSnapshotLoaderSuite extends SparkFunSuite {
+  test("successful snapshot load without auto repair") {
+    // Test auto repair on or off
+    Seq(true, false).foreach { enabled =>
+      val loader = new TestAutoSnapshotLoader(
+        autoSnapshotRepairEnabled = enabled,
+        eligibleSnapshots = Seq(2, 4),
+        failSnapshots = Seq.empty)
+
+      val (versionLoaded, autoRepairCompleted) = loader.loadSnapshot(5)
+      assert(!autoRepairCompleted)
+      assert(versionLoaded == 4, "Should load the latest snapshot version")
+      assert(loader.getRequestedSnapshotVersions == Seq(4),
+        "Should have requested only the latest snapshot version")
+    }
+  }
+
+  test("snapshot load failure gets repaired") {
+    def createLoader(autoRepair: Boolean): TestAutoSnapshotLoader =
+      new TestAutoSnapshotLoader(
+        autoSnapshotRepairEnabled = autoRepair,
+        eligibleSnapshots = Seq(2, 4),
+        failSnapshots = Seq(4))
+
+    // load without auto repair enabled
+    var loader = createLoader(autoRepair = false)
+
+    // This should fail to load v5 due to snapshot 4 failure, even though 
snapshot 2 exists
+    val ex = intercept[TestLoadException] {
+      loader.loadSnapshot(5)
+    }
+    assert(ex.snapshotVersion == 4, "Load failure should be due to version 4")
+
+    // Now try to load with auto repair enabled
+    loader = createLoader(autoRepair = true)
+    val (versionLoaded, autoRepairCompleted) = loader.loadSnapshot(5)
+    assert(autoRepairCompleted)
+    assert(versionLoaded == 2, "Should have loaded the snapshot version before 
the corrupt one")
+    assert(loader.getRequestedSnapshotVersions == Seq(4, 2))
+  }
+
+  test("repair works even when all snapshots are corrupt") {
+    val loader = new TestAutoSnapshotLoader(
+      autoSnapshotRepairEnabled = true,
+      eligibleSnapshots = Seq(2, 4),
+      failSnapshots = Seq(2, 4))
+
+    val (versionLoaded, autoRepairCompleted) = loader.loadSnapshot(5)
+    assert(autoRepairCompleted)
+    assert(versionLoaded == 0, "Load 0 since no good snapshots available")
+    assert(loader.getRequestedSnapshotVersions == Seq(4, 2, 0))
+  }
+
+  test("number of failures before activating auto repair") {
+    def createLoader(numFailures: Int): TestAutoSnapshotLoader =
+      new TestAutoSnapshotLoader(
+        autoSnapshotRepairEnabled = true,
+        numFailuresBeforeActivating = numFailures,
+        eligibleSnapshots = Seq(2, 4),
+        failSnapshots = Seq(4))
+
+    (1 to 5).foreach { numFailures =>
+      val loader = createLoader(numFailures)
+      val (versionLoaded, autoRepairCompleted) = loader.loadSnapshot(5)
+      assert(autoRepairCompleted)
+      assert(versionLoaded == 2, "Should have loaded the snapshot version 
before the corrupt one")
+      assert(loader.getRequestedSnapshotVersions == Seq.fill(numFailures)(4) 
:+ 2,
+        s"should have tried to load version 4 $numFailures times before 
falling back to version 2")
+    }
+  }
+
+  test("maximum change file replay") {
+    def createLoader(maxChangeFileReplay: Int, fail: Seq[Long]): 
TestAutoSnapshotLoader =
+      new TestAutoSnapshotLoader(
+        autoSnapshotRepairEnabled = true,
+        maxChangeFileReplay = maxChangeFileReplay,
+        eligibleSnapshots = Seq(2, 4, 5),
+        failSnapshots = fail)
+
+    var loader = createLoader(maxChangeFileReplay = 1, fail = Seq(5))
+    // repair with max change file replay = 1, should load snapshot 4
+    val (versionLoaded, autoRepairCompleted) = loader.loadSnapshot(5)
+    assert(autoRepairCompleted)
+    assert(versionLoaded == 4)
+    assert(loader.getRequestedSnapshotVersions == Seq(5, 4))
+
+    // repair with max change file replay = 2, should fail since we can't use 
the older snapshots
+    loader = createLoader(maxChangeFileReplay = 2, fail = Seq(5, 4))
+    val ex = intercept[StateStoreAutoSnapshotRepairFailed] {
+      loader.loadSnapshot(5)
+    }
+
+    checkError(
+      exception = ex,
+      condition = "CANNOT_LOAD_STATE_STORE.AUTO_SNAPSHOT_REPAIR_FAILED",
+      parameters = Map(
+        "latestSnapshot" -> "5",
+        "stateStoreId" -> "test",
+        "selectedSnapshots" -> "4", // only selected 4 due to 
maxChangeFileReplay = 2
+        "eligibleSnapshots" -> "4,2,0")
+    )
+    assert(loader.getRequestedSnapshotVersions == Seq(5, 4))
+    assert(ex.getCause.asInstanceOf[TestLoadException].snapshotVersion == 4)
+
+    // repair with max change file replay = 3, should load snapshot 2
+    loader = createLoader(maxChangeFileReplay = 3, fail = Seq(5, 4))
+    val (versionLoaded_, autoRepairCompleted_) = loader.loadSnapshot(5)
+    assert(autoRepairCompleted_)
+    assert(versionLoaded_ == 2)
+    assert(loader.getRequestedSnapshotVersions == Seq(5, 4, 2))
+  }
+}
+
+/**
+ * A test implementation of [[AutoSnapshotLoader]] for testing purposes.
+ * Allows tracking of requested snapshot versions and simulating load failures.
+ * */
+class TestAutoSnapshotLoader(
+    autoSnapshotRepairEnabled: Boolean,
+    numFailuresBeforeActivating: Int = 1,
+    maxChangeFileReplay: Int = 10,
+    loggingId: String = "test",
+    eligibleSnapshots: Seq[Long],
+    failSnapshots: Seq[Long] = Seq.empty) extends AutoSnapshotLoader(
+  autoSnapshotRepairEnabled, numFailuresBeforeActivating, maxChangeFileReplay, 
loggingId) {
+
+  // track snapshot versions requested via loadSnapshotFromCheckpoint
+  private val requestedSnapshotVersions = ListBuffer[Long]()
+  def getRequestedSnapshotVersions: Seq[Long] = requestedSnapshotVersions.toSeq
+
+  override protected def beforeLoad(): Unit = {}
+
+  override protected def loadSnapshotFromCheckpoint(snapshotVersion: Long): 
Unit = {
+    // Track the snapshot version
+    requestedSnapshotVersions += snapshotVersion
+
+    // throw exception if the snapshot version is in the failSnapshots list
+    if (failSnapshots.contains(snapshotVersion)) {
+      throw new TestLoadException(snapshotVersion)
+    }
+  }
+
+  override protected def onLoadSnapshotFromCheckpointFailure(): Unit = {}
+
+  override protected def getEligibleSnapshots(versionToLoad: Long): Seq[Long] 
= eligibleSnapshots
+}
+
+class TestLoadException(val snapshotVersion: Long)
+  extends IllegalStateException(s"Cannot load snapshot version 
$snapshotVersion")
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
index 38e5b15465b8..0bf95ce92797 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
@@ -110,10 +110,13 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest
               "rocksdbTotalBytesReadThroughIterator", 
"rocksdbTotalBytesWrittenByFlush",
               "rocksdbPinnedBlocksMemoryUsage", 
"rocksdbNumInternalColFamiliesKeys",
               "rocksdbNumExternalColumnFamilies", 
"rocksdbNumInternalColumnFamilies",
+              "rocksdbNumSnapshotsAutoRepaired",
               "SnapshotLastUploaded.partition_0_default", 
"rocksdbChangeLogWriterCommitLatencyMs",
               "rocksdbSaveZipFilesLatencyMs", 
"rocksdbLoadFromSnapshotLatencyMs",
               "rocksdbLoadLatencyMs", "rocksdbReplayChangeLogLatencyMs",
               "rocksdbNumReplayChangelogFiles"))
+            
assert(stateOperatorMetrics.customMetrics.get("rocksdbNumSnapshotsAutoRepaired")
 == 0,
+              "Should be 0 since we didn't repair any snapshot")
           }
         } finally {
           query.stop()
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
index da6c3e62798e..de16aa38fe5d 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
@@ -3692,6 +3692,83 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures 
with SharedSparkSession
     }
   }
 
+  testWithChangelogCheckpointingEnabled("Auto snapshot repair") {
+    withSQLConf(
+      SQLConf.STREAMING_CHECKPOINT_FILE_CHECKSUM_ENABLED.key -> false.toString,
+      SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "2"
+    ) {
+      withTempDir { dir =>
+        val remoteDir = dir.getCanonicalPath
+        withDB(remoteDir) { db =>
+          db.load(0)
+          db.put("a", "0")
+          db.commit()
+          assert(db.metricsOpt.get.numSnapshotsAutoRepaired == 0)
+
+          db.load(1)
+          db.put("b", "1")
+          db.commit() // snapshot is created
+          assert(db.metricsOpt.get.numSnapshotsAutoRepaired == 0)
+          db.doMaintenance() // upload snapshot 2.zip
+
+          db.load(2)
+          db.put("c", "2")
+          db.commit()
+          assert(db.metricsOpt.get.numSnapshotsAutoRepaired == 0)
+
+          db.load(3)
+          db.put("d", "3")
+          db.commit() // snapshot is created
+          assert(db.metricsOpt.get.numSnapshotsAutoRepaired == 0)
+          db.doMaintenance() // upload snapshot 4.zip
+        }
+
+        def corruptFile(file: File): Unit =
+          // overwrite the file content to become empty
+          new PrintWriter(file) { close() }
+
+        // corrupt snapshot 4.zip
+        corruptFile(new File(remoteDir, "4.zip"))
+
+        withDB(remoteDir) { db =>
+          // this should fail when trying to load from remote
+          val ex = intercept[java.nio.file.NoSuchFileException] {
+            db.load(4)
+          }
+          // would fail while trying to read the metadata file from the empty 
zip file
+          assert(ex.getMessage.contains("/metadata"))
+        }
+
+        // Enable auto snapshot repair
+        withSQLConf(SQLConf.STATE_STORE_AUTO_SNAPSHOT_REPAIR_ENABLED.key -> 
true.toString,
+          
SQLConf.STATE_STORE_AUTO_SNAPSHOT_REPAIR_NUM_FAILURES_BEFORE_ACTIVATING.key -> 
"1",
+          SQLConf.STATE_STORE_AUTO_SNAPSHOT_REPAIR_MAX_CHANGE_FILE_REPLAY.key 
-> "5"
+        ) {
+          withDB(remoteDir) { db =>
+            // this should now succeed
+            db.load(4)
+            assert(toStr(db.get("a")) == "0")
+            db.put("e", "4")
+            db.commit() // a new snapshot (5.zip) will be created since 
previous one is corrupt
+            assert(db.metricsOpt.get.numSnapshotsAutoRepaired == 1)
+            db.doMaintenance() // upload snapshot 5.zip
+          }
+
+          // corrupt all snapshot files
+          Seq(2, 5).foreach { v => corruptFile(new File(remoteDir, s"$v.zip")) 
}
+
+          withDB(remoteDir) { db =>
+            // this load should succeed due to auto repair, even though all 
snapshots are bad
+            db.load(5)
+            assert(toStr(db.get("b")) == "1")
+            db.commit()
+            assert(db.metricsOpt.get.numSnapshotsAutoRepaired == 1)
+          }
+        }
+      }
+    }
+  }
+
   testWithChangelogCheckpointingEnabled("SPARK-51922 - Changelog writer v1 
with large key" +
     " does not cause UTFDataFormatException") {
     val remoteDir = Utils.createTempDir()
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index 6bb64315e356..807397d96918 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution.streaming.state
 
-import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, 
IOException, ObjectInputStream, ObjectOutputStream}
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, 
IOException, ObjectInputStream, ObjectOutputStream, PrintWriter}
 import java.net.URI
 import java.util
 import java.util.UUID
@@ -1420,6 +1420,92 @@ class StateStoreSuite extends 
StateStoreSuiteBase[HDFSBackedStateStoreProvider]
       "HDFSBackedStateStoreProvider does not support checkpointFormatVersion > 
1"))
   }
 
+  test("Auto snapshot repair") {
+    withSQLConf(
+      SQLConf.STREAMING_CHECKPOINT_FILE_CHECKSUM_ENABLED.key -> false.toString,
+      SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1" // for hdfs means 
every 2 versions
+    ) {
+      val storeId = StateStoreId(newDir(), 0L, 1)
+      val remoteDir = storeId.storeCheckpointLocation().toString
+
+      def numSnapshotsAutoRepaired(store: StateStore): Long = {
+        store.metrics.customMetrics
+          .find(m => m._1.name == "numSnapshotsAutoRepaired").get._2
+      }
+
+      tryWithProviderResource(newStoreProviderWithClonedConf(storeId)) { 
provider =>
+        var store = provider.getStore(0)
+        put(store, "a", 0, 0)
+        store.commit()
+        assert(numSnapshotsAutoRepaired(store) == 0)
+
+        store = provider.getStore(1)
+        put(store, "b", 1, 1)
+        store.commit()
+        assert(numSnapshotsAutoRepaired(store) == 0)
+        provider.doMaintenance() // upload snapshot 2.snapshot
+
+        store = provider.getStore(2)
+        put(store, "c", 2, 2)
+        store.commit()
+        assert(numSnapshotsAutoRepaired(store) == 0)
+
+        store = provider.getStore(3)
+        put(store, "d", 3, 3)
+        store.commit()
+        assert(numSnapshotsAutoRepaired(store) == 0)
+        provider.doMaintenance() // upload snapshot 4.snapshot
+      }
+
+      def corruptFile(file: File): Unit =
+        // overwrite the file content to become empty
+        new PrintWriter(file) { close() }
+
+      // corrupt 4.snapshot
+      corruptFile(new File(remoteDir, "4.snapshot"))
+
+      tryWithProviderResource(newStoreProviderWithClonedConf(storeId)) { 
provider =>
+        // this should fail when trying to load from remote
+        val ex = intercept[SparkException] {
+          provider.getStore(4)
+        }
+        assert(ex.getCause.isInstanceOf[java.io.EOFException])
+      }
+
+      // Enable auto snapshot repair
+      withSQLConf(SQLConf.STATE_STORE_AUTO_SNAPSHOT_REPAIR_ENABLED.key -> 
true.toString,
+        
SQLConf.STATE_STORE_AUTO_SNAPSHOT_REPAIR_NUM_FAILURES_BEFORE_ACTIVATING.key -> 
"1",
+        SQLConf.STATE_STORE_AUTO_SNAPSHOT_REPAIR_MAX_CHANGE_FILE_REPLAY.key -> 
"6"
+      ) {
+        tryWithProviderResource(newStoreProviderWithClonedConf(storeId)) { 
provider =>
+          // this should now succeed
+          var store = provider.getStore(4)
+          assert(get(store, "a", 0).contains(0))
+          put(store, "e", 4, 4)
+          store.commit()
+          assert(numSnapshotsAutoRepaired(store) == 1)
+
+          store = provider.getStore(5)
+          put(store, "f", 5, 5)
+          store.commit()
+          assert(numSnapshotsAutoRepaired(store) == 0)
+          provider.doMaintenance() // upload snapshot 6.snapshot
+        }
+
+        // corrupt all snapshot files
+        Seq(2, 6).foreach { v => corruptFile(new File(remoteDir, 
s"$v.snapshot"))}
+
+        tryWithProviderResource(newStoreProviderWithClonedConf(storeId)) { 
provider =>
+          // this load should succeed due to auto repair, even though all 
snapshots are bad
+          val store = provider.getStore(6)
+          assert(get(store, "b", 1).contains(1))
+          store.commit()
+          assert(numSnapshotsAutoRepaired(store) == 1)
+        }
+      }
+    }
+  }
+
   override def newStoreProvider(): HDFSBackedStateStoreProvider = {
     newStoreProvider(opId = Random.nextInt(), partition = 0)
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
