This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 9bc8cd51b671 [SPARK-51358][SS] Introduce snapshot upload lag detection through StateStoreCoordinator

9bc8cd51b671 is described below

commit 9bc8cd51b671aeab7f4652698c00a72991e216fb
Author: Zeyu Chen <zyc...@gmail.com>
AuthorDate: Fri Apr 11 10:05:28 2025 +0900

[SPARK-51358][SS] Introduce snapshot upload lag detection through StateStoreCoordinator

### What changes were proposed in this pull request?

SPARK-51358

This PR adds detection logic and logging for delays in snapshot uploads across all state store instances (both RocksDB and HDFS-backed). The snapshot upload reporting is done through RPC calls from RocksDB.scala and HDFSBackedStateStoreProvider to the StateStoreCoordinator, so that events are not dependent on streaming query progress reports.

### Why are the changes needed?

This enables observability through dashboards and alerts, helping us understand how frequently snapshot uploads lag in production.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Ten new tests are added in StateStoreCoordinatorSuite (4 RocksDB, 4 HDFS, 3 edge cases), covering both join and non-join stateful queries. One of these tests verifies that the coordinator does not report any state store as lagging if changelog checkpointing is disabled. Another verifies that query restarts do not cause the coordinator to report every instance as lagging. A third uses AvailableNow triggers with short batch durations to verify that repeated starts/stops can still report lagging state stores.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #50123 from zecookiez/SPARK-51358.

Lead-authored-by: Zeyu Chen <zyc...@gmail.com>
Co-authored-by: Zeyu Chen <zyc...@gmail.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../scala/org/apache/spark/internal/LogKey.scala   |   4 +
 .../org/apache/spark/sql/internal/SQLConf.scala    |  79 +++
 .../spark/sql/classic/StreamingQueryManager.scala  |   2 +-
 .../execution/streaming/IncrementalExecution.scala |   3 +-
 .../execution/streaming/MicroBatchExecution.scala  |   3 +-
 .../sql/execution/streaming/ProgressReporter.scala |  18 +
 .../state/HDFSBackedStateStoreProvider.scala       |  24 +-
 .../sql/execution/streaming/state/RocksDB.scala    |  30 +-
 .../state/RocksDBStateStoreProvider.scala          |  58 +-
 .../sql/execution/streaming/state/StateStore.scala |  51 +-
 .../execution/streaming/state/StateStoreConf.scala |   6 +
 .../streaming/state/StateStoreCoordinator.scala    | 257 ++++++++-
 .../org/apache/spark/sql/JavaDatasetSuite.java     |   3 +
 .../FailureInjectionCheckpointFileManager.scala    |  12 +-
 .../RocksDBCheckpointFailureInjectionSuite.scala   |   3 +-
 .../streaming/state/RocksDBStateStoreSuite.scala   |   1 +
 .../state/StateStoreCoordinatorSuite.scala         | 618 ++++++++++++++++++++-
 .../streaming/state/ValueStateSuite.scala          |   1 +
 18 files changed, 1134 insertions(+), 39 deletions(-)

diff --git a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala index be3ba751af69..1f997592dbfb 100644 --- a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala +++ b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala @@ -511,6 +511,7 @@ private[spark] object LogKeys { case object NUM_ITERATIONS extends LogKey case object NUM_KAFKA_PULLS extends LogKey case object NUM_KAFKA_RECORDS_PULLED extends LogKey + case object NUM_LAGGING_STORES extends LogKey 
case object NUM_LEADING_SINGULAR_VALUES extends LogKey case object NUM_LEFT_PARTITION_VALUES extends LogKey case object NUM_LOADED_ENTRIES extends LogKey @@ -751,6 +752,9 @@ private[spark] object LogKeys { case object SLEEP_TIME extends LogKey case object SLIDE_DURATION extends LogKey case object SMALLEST_CLUSTER_INDEX extends LogKey + case object SNAPSHOT_EVENT extends LogKey + case object SNAPSHOT_EVENT_TIME_DELTA extends LogKey + case object SNAPSHOT_EVENT_VERSION_DELTA extends LogKey case object SNAPSHOT_VERSION extends LogKey case object SOCKET_ADDRESS extends LogKey case object SOURCE extends LogKey diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 93acb39944fa..c8571a58f9d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2332,6 +2332,70 @@ object SQLConf { .booleanConf .createWithDefault(true) + val STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG = + buildConf("spark.sql.streaming.stateStore.multiplierForMinVersionDiffToLog") + .internal() + .doc( + "Determines the version threshold for logging warnings when a state store falls behind. " + + "The coordinator logs a warning when the store's uploaded snapshot version trails the " + + "query's latest version by the configured number of deltas needed to create a snapshot, " + + "times this multiplier." + ) + .version("4.1.0") + .longConf + .checkValue(k => k >= 1L, "Must be greater than or equal to 1") + .createWithDefault(5L) + + val STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG = + buildConf("spark.sql.streaming.stateStore.multiplierForMinTimeDiffToLog") + .internal() + .doc( + "Determines the time threshold for logging warnings when a state store falls behind. " + + "The coordinator logs a warning when the store's uploaded snapshot timestamp trails the " + + "current time by the configured maintenance interval, times this multiplier." + ) + .version("4.1.0") + .longConf + .checkValue(k => k >= 1L, "Must be greater than or equal to 1") + .createWithDefault(10L) + + val STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG = + buildConf("spark.sql.streaming.stateStore.coordinatorReportSnapshotUploadLag") + .internal() + .doc( + "When enabled, the state store coordinator will report state stores whose snapshot " + + "have not been uploaded for some time. See the conf snapshotLagReportInterval for " + + "the minimum time between reports, and the conf multiplierForMinVersionDiffToLog " + + "and multiplierForMinTimeDiffToLog for the logging thresholds." + ) + .version("4.1.0") + .booleanConf + .createWithDefault(true) + + val STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL = + buildConf("spark.sql.streaming.stateStore.snapshotLagReportInterval") + .internal() + .doc( + "The minimum amount of time between the state store coordinator's reports on " + + "state store instances trailing behind in snapshot uploads." + ) + .version("4.1.0") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefault(TimeUnit.MINUTES.toMillis(5)) + + val STATE_STORE_COORDINATOR_MAX_LAGGING_STORES_TO_REPORT = + buildConf("spark.sql.streaming.stateStore.maxLaggingStoresToReport") + .internal() + .doc( + "Maximum number of state stores the coordinator will report as trailing in " + + "snapshot uploads. Stores are selected based on the most lagging behind in " + + "snapshot version." 
+ ) + .version("4.1.0") + .intConf + .checkValue(k => k >= 0, "Must be greater than or equal to 0") + .createWithDefault(5) + val FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION = buildConf("spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion") .internal() @@ -5931,6 +5995,21 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def stateStoreSkipNullsForStreamStreamJoins: Boolean = getConf(STATE_STORE_SKIP_NULLS_FOR_STREAM_STREAM_JOINS) + def stateStoreCoordinatorMultiplierForMinVersionDiffToLog: Long = + getConf(STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG) + + def stateStoreCoordinatorMultiplierForMinTimeDiffToLog: Long = + getConf(STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG) + + def stateStoreCoordinatorReportSnapshotUploadLag: Boolean = + getConf(STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG) + + def stateStoreCoordinatorSnapshotLagReportInterval: Long = + getConf(STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL) + + def stateStoreCoordinatorMaxLaggingStoresToReport: Int = + getConf(STATE_STORE_COORDINATOR_MAX_LAGGING_STORES_TO_REPORT) + def checkpointLocation: Option[String] = getConf(CHECKPOINT_LOCATION) def isUnsupportedOperationCheckEnabled: Boolean = getConf(UNSUPPORTED_OPERATION_CHECK_ENABLED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQueryManager.scala index 6ce6f06de113..6d4a3ecd3603 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQueryManager.scala @@ -53,7 +53,7 @@ class StreamingQueryManager private[sql] ( with Logging { private[sql] val stateStoreCoordinator = - StateStoreCoordinatorRef.forDriver(sparkSession.sparkContext.env) + StateStoreCoordinatorRef.forDriver(sparkSession.sparkContext.env, sqlConf) private val listenerBus = new StreamingQueryListenerBus(Some(sparkSession.sparkContext.listenerBus)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 246057a5a9d0..8c1e5e901513 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -71,7 +71,8 @@ class IncrementalExecution( MutableMap[Long, Array[Array[String]]] = MutableMap[Long, Array[Array[String]]](), val stateSchemaMetadatas: MutableMap[Long, StateSchemaBroadcast] = MutableMap[Long, StateSchemaBroadcast](), - mode: CommandExecutionMode.Value = CommandExecutionMode.ALL) + mode: CommandExecutionMode.Value = CommandExecutionMode.ALL, + val isTerminatingTrigger: Boolean = false) extends QueryExecution(sparkSession, logicalPlan, mode = mode) with Logging { // Modified planner with stateful operations. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index c977a499edc0..1dd70ad985cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -858,7 +858,8 @@ class MicroBatchExecution( watermarkPropagator, execCtx.previousContext.isEmpty, currentStateStoreCkptId, - stateSchemaMetadatas) + stateSchemaMetadatas, + isTerminatingTrigger = trigger.isInstanceOf[AvailableNowTrigger.type]) execCtx.executionPlan.executedPlan // Force the lazy generation of execution plan } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index d814a86c84c7..dc04ba3331e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, ReportsSinkMetrics, ReportsSourceMetrics, SparkDataStream} import org.apache.spark.sql.execution.{QueryExecution, StreamSourceAwareSparkPlan} import org.apache.spark.sql.execution.datasources.v2.{MicroBatchScanExec, StreamingDataSourceV2ScanRelation, StreamWriterCommitProgress} +import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.streaming._ import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryIdleEvent, QueryProgressEvent} import org.apache.spark.util.{Clock, Utils} @@ -61,6 +62,12 @@ class ProgressReporter( val noDataProgressEventInterval: Long = sparkSession.sessionState.conf.streamingNoDataProgressEventInterval + val coordinatorReportSnapshotUploadLag: Boolean = + sparkSession.sessionState.conf.stateStoreCoordinatorReportSnapshotUploadLag + + val stateStoreCoordinator: StateStoreCoordinatorRef = + sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator + private val timestampFormat = DateTimeFormatter .ofPattern("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'") // ISO8601 @@ -283,6 +290,17 @@ abstract class ProgressContext( progressReporter.lastNoExecutionProgressEventTime = triggerClock.getTimeMillis() progressReporter.updateProgress(newProgress) + // Ask the state store coordinator to log all lagging state stores + if (progressReporter.coordinatorReportSnapshotUploadLag) { + val latestVersion = lastEpochId + 1 + progressReporter.stateStoreCoordinator + .logLaggingStateStores( + lastExecution.runId, + latestVersion, + lastExecution.isTerminatingTrigger + ) + } + // Update the value since this trigger executes a batch successfully. 
this.execStatsOnLatestExecutedBatch = Some(execStats) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 648fe0f5b1fd..98d49596d11b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.streaming.state import java.io._ import java.util -import java.util.Locale +import java.util.{Locale, UUID} import java.util.concurrent.atomic.{AtomicLong, LongAdder} import scala.collection.mutable @@ -551,6 +551,10 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with val snapshotCurrentVersionMap = readSnapshotFile(version) if (snapshotCurrentVersionMap.isDefined) { synchronized { putStateIntoStateCacheMap(version, snapshotCurrentVersionMap.get) } + + // Report the loaded snapshot's version to the coordinator + reportSnapshotUploadToCoordinator(version) + return snapshotCurrentVersionMap.get } @@ -580,6 +584,10 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with } synchronized { putStateIntoStateCacheMap(version, resultMap) } + + // Report the last available snapshot's version to the coordinator + reportSnapshotUploadToCoordinator(lastAvailableVersion) + resultMap } @@ -699,6 +707,8 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with log"for ${MDC(LogKeys.OP_TYPE, opType)}") // Compare and update with the version that was just uploaded. lastUploadedSnapshotVersion.updateAndGet(v => Math.max(version, v)) + // Report the snapshot upload event to the coordinator + reportSnapshotUploadToCoordinator(version) } /** @@ -1043,6 +1053,18 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with CompressionCodec.createCodec(sparkConf, storeConf.compressionCodec), keySchema, valueSchema) } + + /** Reports to the coordinator the store's latest snapshot version */ + private def reportSnapshotUploadToCoordinator(version: Long): Unit = { + if (storeConf.reportSnapshotUploadLag) { + // Attach the query run ID and current timestamp to the RPC message + val runId = UUID.fromString(StateStoreProvider.getRunId(hadoopConf)) + val currentTimestamp = System.currentTimeMillis() + StateStoreProvider.coordinatorRef.foreach( + _.snapshotUploaded(StateStoreProviderId(stateStoreId, runId), version, currentTimestamp) + ) + } + } } /** [[StateStoreChangeDataReader]] implementation for [[HDFSBackedStateStoreProvider]] */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 15df2fae8260..07553f51c60e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -64,6 +64,7 @@ case object StoreTaskCompletionListener extends RocksDBOpType("store_task_comple * @param stateStoreId StateStoreId for the state store * @param localRootDir Root directory in local disk that is used to working and checkpointing dirs * @param hadoopConf Hadoop configuration for talking to the remote file system + * @param eventForwarder The RocksDBEventForwarder object for 
reporting events to the coordinator */ class RocksDB( dfsRootDir: String, @@ -73,7 +74,8 @@ class RocksDB( loggingId: String = "", useColumnFamilies: Boolean = false, enableStateStoreCheckpointIds: Boolean = false, - partitionId: Int = 0) extends Logging { + partitionId: Int = 0, + eventForwarder: Option[RocksDBEventForwarder] = None) extends Logging { import RocksDB._ @@ -403,6 +405,9 @@ class RocksDB( // Initialize maxVersion upon successful load from DFS fileManager.setMaxSeenVersion(version) + // Report this snapshot version to the coordinator + reportSnapshotUploadToCoordinator(latestSnapshotVersion) + openLocalRocksDB(metadata) if (loadedVersion != version) { @@ -480,6 +485,9 @@ class RocksDB( // Initialize maxVersion upon successful load from DFS fileManager.setMaxSeenVersion(version) + // Report this snapshot version to the coordinator + reportSnapshotUploadToCoordinator(latestSnapshotVersion) + openLocalRocksDB(metadata) if (loadedVersion != version) { @@ -617,6 +625,8 @@ class RocksDB( loadedVersion = -1 // invalidate loaded data throw t } + // Report this snapshot version to the coordinator + reportSnapshotUploadToCoordinator(snapshotVersion) this } @@ -1495,6 +1505,8 @@ class RocksDB( log"Current lineage: ${MDC(LogKeys.LINEAGE, lineageManager)}") // Compare and update with the version that was just uploaded. lastUploadedSnapshotVersion.updateAndGet(v => Math.max(snapshot.version, v)) + // Report snapshot upload event to the coordinator. + reportSnapshotUploadToCoordinator(snapshot.version) } finally { snapshot.close() } @@ -1502,6 +1514,16 @@ class RocksDB( fileManagerMetrics } + /** Reports to the coordinator with the event listener that a snapshot finished uploading */ + private def reportSnapshotUploadToCoordinator(version: Long): Unit = { + if (conf.reportSnapshotUploadLag) { + // Note that we still report snapshot versions even when changelog checkpointing is disabled. + // The coordinator needs a way to determine whether upload messages are disabled or not, + // which would be different between RocksDB and HDFS stores due to changelog checkpointing. + eventForwarder.foreach(_.reportSnapshotUploaded(version)) + } + } + /** Create a native RocksDB logger that forwards native logs to log4j with correct log levels. 
*/ private def createLogger(): Logger = { val dbLogger = new Logger(rocksDbOptions.infoLogLevel()) { @@ -1768,7 +1790,8 @@ case class RocksDBConf( highPriorityPoolRatio: Double, compressionCodec: String, allowFAllocate: Boolean, - compression: String) + compression: String, + reportSnapshotUploadLag: Boolean) object RocksDBConf { /** Common prefix of all confs in SQLConf that affects RocksDB */ @@ -1951,7 +1974,8 @@ object RocksDBConf { getRatioConf(HIGH_PRIORITY_POOL_RATIO_CONF), storeConf.compressionCodec, getBooleanConf(ALLOW_FALLOCATE_CONF), - getStringConf(COMPRESSION_CONF)) + getStringConf(COMPRESSION_CONF), + storeConf.reportSnapshotUploadLag) } def apply(): RocksDBConf = apply(new StateStoreConf()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 601caaa34290..6a36b8c01519 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -32,7 +32,7 @@ import org.apache.spark.internal.LogKeys._ import org.apache.spark.io.CompressionCodec import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, StreamExecution} +import org.apache.spark.sql.execution.streaming.CheckpointFileManager import org.apache.spark.sql.execution.streaming.state.StateStoreEncoding.Avro import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.Platform @@ -67,7 +67,7 @@ private[sql] class RocksDBStateStoreProvider verifyColFamilyCreationOrDeletion("create_col_family", colFamilyName, isInternal) val cfId = rocksDB.createColFamilyIfAbsent(colFamilyName, isInternal) val dataEncoderCacheKey = StateRowEncoderCacheKey( - queryRunId = getRunId(hadoopConf), + queryRunId = StateStoreProvider.getRunId(hadoopConf), operatorId = stateStoreId.operatorId, partitionId = stateStoreId.partitionId, stateStoreName = stateStoreId.storeName, @@ -390,6 +390,8 @@ private[sql] class RocksDBStateStoreProvider this.useColumnFamilies = useColumnFamilies this.stateStoreEncoding = storeConf.stateStoreEncodingFormat this.stateSchemaProvider = stateSchemaProvider + this.rocksDBEventForwarder = + Some(RocksDBEventForwarder(StateStoreProvider.getRunId(hadoopConf), stateStoreId)) if (useMultipleValuesPerKey) { require(useColumnFamilies, "Multiple values per key support requires column families to be" + @@ -399,7 +401,7 @@ private[sql] class RocksDBStateStoreProvider rocksDB // lazy initialization val dataEncoderCacheKey = StateRowEncoderCacheKey( - queryRunId = getRunId(hadoopConf), + queryRunId = StateStoreProvider.getRunId(hadoopConf), operatorId = stateStoreId.operatorId, partitionId = stateStoreId.partitionId, stateStoreName = stateStoreId.storeName, @@ -523,6 +525,7 @@ private[sql] class RocksDBStateStoreProvider @volatile private var useColumnFamilies: Boolean = _ @volatile private var stateStoreEncoding: String = _ @volatile private var stateSchemaProvider: Option[StateSchemaProvider] = _ + @volatile private var rocksDBEventForwarder: Option[RocksDBEventForwarder] = _ protected def createRocksDB( dfsRootDir: String, @@ -532,7 +535,8 @@ private[sql] class RocksDBStateStoreProvider loggingId: String, useColumnFamilies: Boolean, enableStateStoreCheckpointIds: 
Boolean, - partitionId: Int = 0): RocksDB = { + partitionId: Int = 0, + eventForwarder: Option[RocksDBEventForwarder] = None): RocksDB = { new RocksDB( dfsRootDir, conf, @@ -541,7 +545,8 @@ private[sql] class RocksDBStateStoreProvider loggingId, useColumnFamilies, enableStateStoreCheckpointIds, - partitionId) + partitionId, + eventForwarder) } private[sql] lazy val rocksDB = { @@ -551,7 +556,8 @@ private[sql] class RocksDBStateStoreProvider val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) val localRootDir = Utils.createTempDir(Utils.getLocalDir(sparkConf), storeIdStr) createRocksDB(dfsRootDir, RocksDBConf(storeConf), localRootDir, hadoopConf, storeIdStr, - useColumnFamilies, storeConf.enableStateStoreCheckpointIds, stateStoreId.partitionId) + useColumnFamilies, storeConf.enableStateStoreCheckpointIds, stateStoreId.partitionId, + rocksDBEventForwarder) } private val keyValueEncoderMap = new java.util.concurrent.ConcurrentHashMap[String, @@ -822,16 +828,6 @@ object RocksDBStateStoreProvider { ) } - private def getRunId(hadoopConf: Configuration): String = { - val runId = hadoopConf.get(StreamExecution.RUN_ID_KEY) - if (runId != null) { - runId - } else { - assert(Utils.isTesting, "Failed to find query id/batch Id in task context") - UUID.randomUUID().toString - } - } - // Native operation latencies report as latency in microseconds // as SQLMetrics support millis. Convert the value to millis val CUSTOM_METRIC_GET_TIME = StateStoreCustomTimingMetric( @@ -991,3 +987,33 @@ class RocksDBStateStoreChangeDataReader( } } } + +/** + * Class used to relay events reported from a RocksDB instance to the state store coordinator. + * + * We pass this into the RocksDB instance to report specific events like snapshot uploads. + * This should only be used to report back to the coordinator for metrics and monitoring purposes. + */ +private[state] case class RocksDBEventForwarder(queryRunId: String, stateStoreId: StateStoreId) { + // Build the state store provider ID from the query run ID and the state store ID + private val providerId = StateStoreProviderId(stateStoreId, UUID.fromString(queryRunId)) + + /** + * Callback function from RocksDB to report events to the coordinator. + * Information from the store provider such as the state store ID and query run ID are + * attached here to report back to the coordinator. + * + * @param version The snapshot version that was just uploaded from RocksDB + */ + def reportSnapshotUploaded(version: Long): Unit = { + // Report the state store provider ID and the version to the coordinator + val currentTimestamp = System.currentTimeMillis() + StateStoreProvider.coordinatorRef.foreach( + _.snapshotUploaded( + providerId, + version, + currentTimestamp + ) + ) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index ccb925287e77..63936305c7cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -593,7 +593,15 @@ trait StateStoreProvider { def supportedInstanceMetrics: Seq[StateStoreInstanceMetric] = Seq.empty } -object StateStoreProvider { +object StateStoreProvider extends Logging { + + /** + * The state store coordinator reference used to report events such as snapshot uploads from + * the state store providers. 
+ * For all other messages, refer to the coordinator reference in the [[StateStore]] object. + */ + @GuardedBy("this") + private var stateStoreCoordinatorRef: StateStoreCoordinatorRef = _ /** * Return a instance of the given provider class name. The instance will not be initialized. @@ -652,6 +660,47 @@ object StateStoreProvider { } } } + + /** + * Get the runId from the provided hadoopConf. If it is not found, generate a random UUID. + * + * @param hadoopConf Hadoop configuration used by the StateStore to save state data + */ + private[state] def getRunId(hadoopConf: Configuration): String = { + val runId = hadoopConf.get(StreamExecution.RUN_ID_KEY) + if (runId != null) { + runId + } else { + assert(Utils.isTesting, "Failed to find query id/batch Id in task context") + UUID.randomUUID().toString + } + } + + /** + * Create the state store coordinator reference which will be reused across state store providers + * in the executor. + * This coordinator reference should only be used to report events from store providers regarding + * snapshot uploads to avoid lock contention with other coordinator RPC messages. + */ + private[state] def coordinatorRef: Option[StateStoreCoordinatorRef] = synchronized { + val env = SparkEnv.get + if (env != null) { + val isDriver = env.executorId == SparkContext.DRIVER_IDENTIFIER + // If running locally, then the coordinator reference in stateStoreCoordinatorRef may have + // become inactive as SparkContext + SparkEnv may have been restarted. Hence, when running in + // driver, always recreate the reference. + if (isDriver || stateStoreCoordinatorRef == null) { + logDebug("Getting StateStoreCoordinatorRef") + stateStoreCoordinatorRef = StateStoreCoordinatorRef.forExecutor(env) + } + logInfo(log"Retrieved reference to StateStoreCoordinator: " + + log"${MDC(LogKeys.STATE_STORE_COORDINATOR, stateStoreCoordinatorRef)}") + Some(stateStoreCoordinatorRef) + } else { + stateStoreCoordinatorRef = null + None + } + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index 807534ee4569..e0450cfc4f69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -97,6 +97,12 @@ class StateStoreConf( val enableStateStoreCheckpointIds = StatefulOperatorStateInfo.enableStateStoreCheckpointIds(sqlConf) + /** + * Whether the coordinator is reporting state stores trailing behind in snapshot uploads. + */ + val reportSnapshotUploadLag: Boolean = + sqlConf.stateStoreCoordinatorReportSnapshotUploadLag + /** * Additional configurations related to state store. 
This will capture all configs in * SQLConf that start with `spark.sql.streaming.stateStore.` diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index 84b77efea3ca..903f27fb2a22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -22,9 +22,10 @@ import java.util.UUID import scala.collection.mutable import org.apache.spark.SparkEnv -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.RpcUtils /** Trait representing all messages to [[StateStoreCoordinator]] */ @@ -55,6 +56,45 @@ private case class GetLocation(storeId: StateStoreProviderId) private case class DeactivateInstances(runId: UUID) extends StateStoreCoordinatorMessage +/** + * This message is used to report a state store has just finished uploading a snapshot, + * along with the timestamp in milliseconds and the snapshot version. + */ +private case class ReportSnapshotUploaded( + providerId: StateStoreProviderId, + version: Long, + timestamp: Long) + extends StateStoreCoordinatorMessage + +/** + * This message is used for the coordinator to look for all state stores that are lagging behind + * in snapshot uploads. The coordinator will then log a warning message for each lagging instance. + */ +private case class LogLaggingStateStores( + queryRunId: UUID, + latestVersion: Long, + isTerminatingTrigger: Boolean) + extends StateStoreCoordinatorMessage + +/** + * Message used for testing. + * This message is used to retrieve the latest snapshot version reported for upload from a + * specific state store. + */ +private case class GetLatestSnapshotVersionForTesting(providerId: StateStoreProviderId) + extends StateStoreCoordinatorMessage + +/** + * Message used for testing. + * This message is used to retrieve all active state store instances falling behind in + * snapshot uploads, using version and time criteria. 
+ */ +private case class GetLaggingStoresForTesting( + queryRunId: UUID, + latestVersion: Long, + isTerminatingTrigger: Boolean) + extends StateStoreCoordinatorMessage + private object StopCoordinator extends StateStoreCoordinatorMessage @@ -66,9 +106,9 @@ object StateStoreCoordinatorRef extends Logging { /** * Create a reference to a [[StateStoreCoordinator]] */ - def forDriver(env: SparkEnv): StateStoreCoordinatorRef = synchronized { + def forDriver(env: SparkEnv, sqlConf: SQLConf): StateStoreCoordinatorRef = synchronized { try { - val coordinator = new StateStoreCoordinator(env.rpcEnv) + val coordinator = new StateStoreCoordinator(env.rpcEnv, sqlConf) val coordinatorRef = env.rpcEnv.setupEndpoint(endpointName, coordinator) logInfo("Registered StateStoreCoordinator endpoint") new StateStoreCoordinatorRef(coordinatorRef) @@ -119,6 +159,46 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { rpcEndpointRef.askSync[Boolean](DeactivateInstances(runId)) } + /** Inform that an executor has uploaded a snapshot */ + private[sql] def snapshotUploaded( + providerId: StateStoreProviderId, + version: Long, + timestamp: Long): Boolean = { + rpcEndpointRef.askSync[Boolean](ReportSnapshotUploaded(providerId, version, timestamp)) + } + + /** Ask the coordinator to log all state store instances that are lagging behind in uploads */ + private[sql] def logLaggingStateStores( + queryRunId: UUID, + latestVersion: Long, + isTerminatingTrigger: Boolean): Boolean = { + rpcEndpointRef.askSync[Boolean]( + LogLaggingStateStores(queryRunId, latestVersion, isTerminatingTrigger)) + } + + /** + * Endpoint used for testing. + * Get the latest snapshot version uploaded for a state store. + */ + private[state] def getLatestSnapshotVersionForTesting( + providerId: StateStoreProviderId): Option[Long] = { + rpcEndpointRef.askSync[Option[Long]](GetLatestSnapshotVersionForTesting(providerId)) + } + + /** + * Endpoint used for testing. + * Get the state store instances that are falling behind in snapshot uploads for a particular + * query run. + */ + private[state] def getLaggingStoresForTesting( + queryRunId: UUID, + latestVersion: Long, + isTerminatingTrigger: Boolean = false): Seq[StateStoreProviderId] = { + rpcEndpointRef.askSync[Seq[StateStoreProviderId]]( + GetLaggingStoresForTesting(queryRunId, latestVersion, isTerminatingTrigger) + ) + } + private[state] def stop(): Unit = { rpcEndpointRef.askSync[Boolean](StopCoordinator) } @@ -129,10 +209,30 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { * Class for coordinating instances of [[StateStore]]s loaded in executors across the cluster, * and get their locations for job scheduling. 
*/ -private class StateStoreCoordinator(override val rpcEnv: RpcEnv) - extends ThreadSafeRpcEndpoint with Logging { +private class StateStoreCoordinator( + override val rpcEnv: RpcEnv, + val sqlConf: SQLConf) + extends ThreadSafeRpcEndpoint with Logging { private val instances = new mutable.HashMap[StateStoreProviderId, ExecutorCacheTaskLocation] + // Stores the latest snapshot upload event for a specific state store + private val stateStoreLatestUploadedSnapshot = + new mutable.HashMap[StateStoreProviderId, SnapshotUploadEvent] + + // Default snapshot upload event to use when a provider has never uploaded a snapshot + private val defaultSnapshotUploadEvent = SnapshotUploadEvent(0, 0) + + // Stores the last timestamp in milliseconds for each queryRunId indicating when the + // coordinator did a report on instances lagging behind on snapshot uploads. + // The initial timestamp is defaulted to 0 milliseconds. + private val lastFullSnapshotLagReportTimeMs = new mutable.HashMap[UUID, Long] + + private def shouldCoordinatorReportSnapshotLag: Boolean = + sqlConf.stateStoreCoordinatorReportSnapshotUploadLag + + private def coordinatorLagReportInterval: Long = + sqlConf.stateStoreCoordinatorSnapshotLagReportInterval + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case ReportActiveInstance(id, host, executorId, providerIdsToCheck) => logDebug(s"Reported state store $id is active at $executorId") @@ -164,13 +264,160 @@ private class StateStoreCoordinator(override val rpcEnv: RpcEnv) val storeIdsToRemove = instances.keys.filter(_.queryRunId == runId).toSeq instances --= storeIdsToRemove + // Also remove these instances from snapshot upload event tracking + stateStoreLatestUploadedSnapshot --= storeIdsToRemove + // Remove the corresponding run id entries for report time and starting time + lastFullSnapshotLagReportTimeMs -= runId logDebug(s"Deactivating instances related to checkpoint location $runId: " + storeIdsToRemove.mkString(", ")) context.reply(true) + case ReportSnapshotUploaded(providerId, version, timestamp) => + // Ignore this upload event if the registered latest version for the store is more recent, + // since it's possible that an older version gets uploaded after a new executor uploads for + // the same state store but with a newer snapshot. + logDebug(s"Snapshot version $version was uploaded for state store $providerId") + if (!stateStoreLatestUploadedSnapshot.get(providerId).exists(_.version >= version)) { + stateStoreLatestUploadedSnapshot.put(providerId, SnapshotUploadEvent(version, timestamp)) + } + context.reply(true) + + case LogLaggingStateStores(queryRunId, latestVersion, isTerminatingTrigger) => + val currentTimestamp = System.currentTimeMillis() + // Only log lagging instances if snapshot lag reporting and uploading is enabled, + // otherwise all instances will be considered lagging. + if (shouldCoordinatorReportSnapshotLag) { + val laggingStores = + findLaggingStores(queryRunId, latestVersion, currentTimestamp, isTerminatingTrigger) + if (laggingStores.nonEmpty) { + logWarning( + log"StateStoreCoordinator Snapshot Lag Report for " + + log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " + + log"Number of state stores falling behind: " + + log"${MDC(LogKeys.NUM_LAGGING_STORES, laggingStores.size)}" + ) + // Report all stores that are behind in snapshot uploads. + // Only report the list of providers lagging behind if the last reported time + // is not recent for this query run. 
The lag report interval denotes the minimum + // time between these full reports. + val timeSinceLastReport = + currentTimestamp - lastFullSnapshotLagReportTimeMs.getOrElse(queryRunId, 0L) + if (timeSinceLastReport > coordinatorLagReportInterval) { + // Mark timestamp of the report and log the lagging instances + lastFullSnapshotLagReportTimeMs.put(queryRunId, currentTimestamp) + // Only report the stores that are lagging the most behind in snapshot uploads. + laggingStores + .sortBy(stateStoreLatestUploadedSnapshot.getOrElse(_, defaultSnapshotUploadEvent)) + .take(sqlConf.stateStoreCoordinatorMaxLaggingStoresToReport) + .foreach { providerId => + val baseLogMessage = + log"StateStoreCoordinator Snapshot Lag Detected for " + + log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " + + log"Store ID: ${MDC(LogKeys.STATE_STORE_ID, providerId.storeId)} " + + log"(Latest batch ID: ${MDC(LogKeys.BATCH_ID, latestVersion)}" + + val logMessage = stateStoreLatestUploadedSnapshot.get(providerId) match { + case Some(snapshotEvent) => + val versionDelta = latestVersion - snapshotEvent.version + val timeDelta = currentTimestamp - snapshotEvent.timestamp + + baseLogMessage + log", " + + log"latest snapshot: ${MDC(LogKeys.SNAPSHOT_EVENT, snapshotEvent)}, " + + log"version delta: " + + log"${MDC(LogKeys.SNAPSHOT_EVENT_VERSION_DELTA, versionDelta)}, " + + log"time delta: ${MDC(LogKeys.SNAPSHOT_EVENT_TIME_DELTA, timeDelta)}ms)" + case None => + baseLogMessage + log", latest snapshot: no upload for query run)" + } + logWarning(logMessage) + } + } + } + } + context.reply(true) + + case GetLatestSnapshotVersionForTesting(providerId) => + val version = stateStoreLatestUploadedSnapshot.get(providerId).map(_.version) + logDebug(s"Got latest snapshot version of the state store $providerId: $version") + context.reply(version) + + case GetLaggingStoresForTesting(queryRunId, latestVersion, isTerminatingTrigger) => + val currentTimestamp = System.currentTimeMillis() + // Only report if snapshot lag reporting is enabled + if (shouldCoordinatorReportSnapshotLag) { + val laggingStores = + findLaggingStores(queryRunId, latestVersion, currentTimestamp, isTerminatingTrigger) + logDebug(s"Got lagging state stores: ${laggingStores.mkString(", ")}") + context.reply(laggingStores) + } else { + context.reply(Seq.empty) + } + case StopCoordinator => stop() // Stop before replying to ensure that endpoint name has been deregistered logInfo("StateStoreCoordinator stopped") context.reply(true) } + + private def findLaggingStores( + queryRunId: UUID, + referenceVersion: Long, + referenceTimestamp: Long, + isTerminatingTrigger: Boolean): Seq[StateStoreProviderId] = { + // Determine alert thresholds from configurations for both time and version differences. + val snapshotVersionDeltaMultiplier = + sqlConf.stateStoreCoordinatorMultiplierForMinVersionDiffToLog + val maintenanceIntervalMultiplier = sqlConf.stateStoreCoordinatorMultiplierForMinTimeDiffToLog + val minDeltasForSnapshot = sqlConf.stateStoreMinDeltasForSnapshot + val maintenanceInterval = sqlConf.streamingMaintenanceInterval + + // Use the configured multipliers multiplierForMinVersionDiffToLog and + // multiplierForMinTimeDiffToLog to determine the proper alert thresholds. + val minVersionDeltaForLogging = snapshotVersionDeltaMultiplier * minDeltasForSnapshot + val minTimeDeltaForLogging = maintenanceIntervalMultiplier * maintenanceInterval + + // Look for active state store providers that are lagging behind in snapshot uploads. 
+ // The coordinator should only consider providers that are part of this specific query run. + instances.view.keys + .filter(_.queryRunId == queryRunId) + .filter { storeProviderId => + // Stores that didn't upload a snapshot will be treated as a store with a snapshot of + // version 0 and timestamp 0ms. + val latestSnapshot = stateStoreLatestUploadedSnapshot.getOrElse( + storeProviderId, + defaultSnapshotUploadEvent + ) + // Mark a state store as lagging if it's behind in both version and time. + // A state store is considered lagging if it's behind in both version and time according + // to the configured thresholds. + val isBehindOnVersions = + referenceVersion - latestSnapshot.version > minVersionDeltaForLogging + val isBehindOnTime = + referenceTimestamp - latestSnapshot.timestamp > minTimeDeltaForLogging + // If the query is using a trigger that self-terminates like OneTimeTrigger + // and AvailableNowTrigger, we ignore the time threshold check as the upload frequency + // is not fully dependent on the maintenance interval. + isBehindOnVersions && (isTerminatingTrigger || isBehindOnTime) + }.toSeq + } +} + +case class SnapshotUploadEvent( + version: Long, + timestamp: Long +) extends Ordered[SnapshotUploadEvent] { + + override def compare(otherEvent: SnapshotUploadEvent): Int = { + // Compare by version first, then by timestamp as tiebreaker + val versionCompare = this.version.compare(otherEvent.version) + if (versionCompare == 0) { + this.timestamp.compare(otherEvent.timestamp) + } else { + versionCompare + } + } + + override def toString(): String = { + s"SnapshotUploadEvent(version=$version, timestamp=$timestamp)" + } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index bffd2c5d9f70..692b5c0ebc3a 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -63,6 +63,9 @@ public class JavaDatasetSuite implements Serializable { spark = new TestSparkSession(); jsc = new JavaSparkContext(spark.sparkContext()); spark.loadTestData(); + + // Initialize state store coordinator endpoint + spark.streams().stateStoreCoordinator(); } @AfterEach diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FailureInjectionCheckpointFileManager.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FailureInjectionCheckpointFileManager.scala index 4711a45804fb..ebbdd1ad63ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FailureInjectionCheckpointFileManager.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FailureInjectionCheckpointFileManager.scala @@ -251,7 +251,8 @@ class FailureInjectionRocksDBStateStoreProvider extends RocksDBStateStoreProvide loggingId: String, useColumnFamilies: Boolean, enableStateStoreCheckpointIds: Boolean, - partitionId: Int): RocksDB = { + partitionId: Int, + eventForwarder: Option[RocksDBEventForwarder] = None): RocksDB = { FailureInjectionRocksDBStateStoreProvider.createRocksDBWithFaultInjection( dfsRootDir, conf, @@ -260,7 +261,8 @@ class FailureInjectionRocksDBStateStoreProvider extends RocksDBStateStoreProvide loggingId, useColumnFamilies, enableStateStoreCheckpointIds, - partitionId) + partitionId, + eventForwarder) } } @@ -277,7 +279,8 @@ object FailureInjectionRocksDBStateStoreProvider { loggingId: String, useColumnFamilies: 
Boolean, enableStateStoreCheckpointIds: Boolean, - partitionId: Int): RocksDB = { + partitionId: Int, + eventForwarder: Option[RocksDBEventForwarder]): RocksDB = { new RocksDB( dfsRootDir, conf = conf, @@ -286,7 +289,8 @@ object FailureInjectionRocksDBStateStoreProvider { loggingId = loggingId, useColumnFamilies = useColumnFamilies, enableStateStoreCheckpointIds = enableStateStoreCheckpointIds, - partitionId = partitionId + partitionId = partitionId, + eventForwarder = eventForwarder ) { override def createFileManager( dfsRootDir: String, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBCheckpointFailureInjectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBCheckpointFailureInjectionSuite.scala index 31fc51c4d56f..5c24ec209036 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBCheckpointFailureInjectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBCheckpointFailureInjectionSuite.scala @@ -523,7 +523,8 @@ class RocksDBCheckpointFailureInjectionSuite extends StreamTest loggingId = s"[Thread-${Thread.currentThread.getId}]", useColumnFamilies = true, enableStateStoreCheckpointIds = enableStateStoreCheckpointIds, - partitionId = 0) + partitionId = 0, + eventForwarder = None) db.load(version, checkpointId) func(db) } finally { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala index 5aea0077e2aa..b13508682188 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala @@ -52,6 +52,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid before { StateStore.stop() require(!StateStore.isMaintenanceRunning) + spark.streams.stateStoreCoordinator // initialize the lazy coordinator } after { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index 2ebc533f7137..09118edc4357 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -26,8 +26,10 @@ import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite} import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper} -import org.apache.spark.sql.functions.count -import org.apache.spark.sql.internal.SQLConf.SHUFFLE_PARTITIONS +import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide} +import org.apache.spark.sql.functions.{count, expr} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.{StreamingQuery, StreamTest, Trigger} import org.apache.spark.util.Utils class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { @@ -102,7 +104,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { test("multiple references have same 
underlying coordinator") { withCoordinatorRef(sc) { coordRef1 => - val coordRef2 = StateStoreCoordinatorRef.forDriver(sc.env) + val coordRef2 = StateStoreCoordinatorRef.forDriver(sc.env, new SQLConf) val id = StateStoreProviderId(StateStoreId("x", 0, 0), UUID.randomUUID) @@ -125,7 +127,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { import spark.implicits._ coordRef = spark.streams.stateStoreCoordinator implicit val sqlContext = spark.sqlContext - spark.conf.set(SHUFFLE_PARTITIONS.key, "1") + spark.conf.set(SQLConf.SHUFFLE_PARTITIONS.key, "1") // Start a query and run a batch to load state stores val inputData = MemoryStream[Int] @@ -155,16 +157,622 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { StateStore.stop() } } + + private val allJoinStateStoreNames: Seq[String] = + SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide) + + /** Lists the state store providers used for a test, and the set of lagging partition IDs */ + private val regularStateStoreProviders = Seq( + ("RocksDBStateStoreProvider", classOf[RocksDBStateStoreProvider].getName, Set.empty[Int]), + ("HDFSStateStoreProvider", classOf[HDFSBackedStateStoreProvider].getName, Set.empty[Int]) + ) + + /** Lists the state store providers used for a test, and the set of lagging partition IDs */ + private val faultyStateStoreProviders = Seq( + ( + "RocksDBSkipMaintenanceOnCertainPartitionsProvider", + classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName, + Set(0, 1) + ), + ( + "HDFSBackedSkipMaintenanceOnCertainPartitionsProvider", + classOf[HDFSBackedSkipMaintenanceOnCertainPartitionsProvider].getName, + Set(0, 1) + ) + ) + + private val allStateStoreProviders = + regularStateStoreProviders ++ faultyStateStoreProviders + + /** + * Verifies snapshot upload RPC messages from state stores are registered and verifies + * the coordinator detected the correct lagging partitions. + */ + private def verifySnapshotUploadEvents( + coordRef: StateStoreCoordinatorRef, + query: StreamingQuery, + badPartitions: Set[Int], + storeNames: Seq[String] = Seq(StateStoreId.DEFAULT_STORE_NAME)): Unit = { + val streamingQuery = query.asInstanceOf[StreamingQueryWrapper].streamingQuery + val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation + val latestVersion = streamingQuery.lastProgress.batchId + 1 + + // Verify all stores have uploaded a snapshot and it's logged by the coordinator + (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { + partitionId => + // Verify for every store name listed + storeNames.foreach { storeName => + val storeId = StateStoreId(stateCheckpointDir, 0, partitionId, storeName) + val providerId = StateStoreProviderId(storeId, query.runId) + val latestSnapshotVersion = coordRef.getLatestSnapshotVersionForTesting(providerId) + if (badPartitions.contains(partitionId)) { + assert(latestSnapshotVersion.getOrElse(0) == 0) + } else { + assert(latestSnapshotVersion.get >= 0) + } + } + } + // Verify that only the bad partitions are all marked as lagging. + // Join queries should have all their state stores marked as lagging, + // which would be 4 stores per partition instead of 1. 
+ val laggingStores = coordRef.getLaggingStoresForTesting(query.runId, latestVersion) + assert(laggingStores.size == badPartitions.size * storeNames.size) + assert(laggingStores.map(_.storeId.partitionId).toSet == badPartitions) + } + + /** Sets up a stateful dropDuplicate query for testing */ + private def setUpStatefulQuery( + inputData: MemoryStream[Int], queryName: String): StreamingQuery = { + // Set up a stateful drop duplicate query + val aggregated = inputData.toDF().dropDuplicates() + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val query = aggregated.writeStream + .format("memory") + .outputMode("update") + .queryName(queryName) + .option("checkpointLocation", checkpointLocation.toString) + .start() + query + } + + allStateStoreProviders.foreach { case (providerName, providerClassName, badPartitions) => + test( + s"SPARK-51358: Snapshot uploads in $providerName are properly reported to the coordinator" + ) { + withCoordinatorAndSQLConf( + sc, + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" + ) { + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + val inputData = MemoryStream[Int] + val query = setUpStatefulQuery(inputData, "query") + // Add, commit, and wait multiple times to force snapshot versions and time difference + (0 until 6).foreach { _ => + inputData.addData(1, 2, 3) + query.processAllAvailable() + Thread.sleep(500) + } + // Verify only the partitions in badPartitions are marked as lagging + verifySnapshotUploadEvents(coordRef, query, badPartitions) + query.stop() + } + } + } + + allStateStoreProviders.foreach { case (providerName, providerClassName, badPartitions) => + test( + s"SPARK-51358: Snapshot uploads for join queries with $providerName are properly " + + s"reported to the coordinator" + ) { + withCoordinatorAndSQLConf( + sc, + SQLConf.SHUFFLE_PARTITIONS.key -> "3", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "5", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0", + SQLConf.STATE_STORE_COORDINATOR_MAX_LAGGING_STORES_TO_REPORT.key -> "5" + ) { + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + // Start a join query and run some data to force snapshot uploads + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + val df1 = input1.toDF().select($"value" as "leftKey", ($"value" * 2) as "leftValue") + val df2 = input2.toDF().select($"value" as "rightKey", ($"value" * 3) as "rightValue") + val joined = df1.join(df2, expr("leftKey = 
rightKey")) + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val query = joined.writeStream + .format("memory") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + // Add, commit, and wait multiple times to force snapshot versions and time difference + (0 until 7).foreach { _ => + input1.addData(1, 5) + input2.addData(1, 5, 10) + query.processAllAvailable() + Thread.sleep(500) + } + // Verify only the partitions in badPartitions are marked as lagging + verifySnapshotUploadEvents(coordRef, query, badPartitions, allJoinStateStoreNames) + query.stop() + } + } + } + + test("SPARK-51358: Verify coordinator properly handles simultaneous query runs") { + withCoordinatorAndSQLConf( + sc, + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" + ) { + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + // Start and run two queries together with some data to force snapshot uploads + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + val query1 = setUpStatefulQuery(input1, "query1") + val query2 = setUpStatefulQuery(input2, "query2") + + // Go through several rounds of input to force snapshot uploads for both queries + (0 until 2).foreach { _ => + input1.addData(1, 2, 3) + input2.addData(1, 2, 3) + query1.processAllAvailable() + query2.processAllAvailable() + // Process twice the amount of data for the first query + input1.addData(1, 2, 3) + query1.processAllAvailable() + Thread.sleep(1000) + } + // Verify that the coordinator logged the correct lagging stores for the first query + val streamingQuery1 = query1.asInstanceOf[StreamingQueryWrapper].streamingQuery + val latestVersion1 = streamingQuery1.lastProgress.batchId + 1 + val laggingStores1 = coordRef.getLaggingStoresForTesting(query1.runId, latestVersion1) + + assert(laggingStores1.size == 2) + assert(laggingStores1.forall(_.storeId.partitionId <= 1)) + assert(laggingStores1.forall(_.queryRunId == query1.runId)) + + // Verify that the second query run hasn't reported anything yet due to lack of data + val streamingQuery2 = query2.asInstanceOf[StreamingQueryWrapper].streamingQuery + var latestVersion2 = streamingQuery2.lastProgress.batchId + 1 + var laggingStores2 = coordRef.getLaggingStoresForTesting(query2.runId, latestVersion2) + assert(laggingStores2.isEmpty) + + // Process some more data for the second query to force lag reports + input2.addData(1, 2, 3) + query2.processAllAvailable() + Thread.sleep(500) + + // Verify that the coordinator logged the correct lagging stores for the second query + latestVersion2 = streamingQuery2.lastProgress.batchId + 1 + laggingStores2 = coordRef.getLaggingStoresForTesting(query2.runId, latestVersion2) + + assert(laggingStores2.size == 2) + assert(laggingStores2.forall(_.storeId.partitionId <= 1)) + assert(laggingStores2.forall(_.queryRunId == query2.runId)) + } + } + + test( + "SPARK-51358: 
Snapshot uploads in RocksDB are not reported if changelog " + + "checkpointing is disabled" + ) { + withCoordinatorAndSQLConf( + sc, + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "false", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG.key -> "1", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "1", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" + ) { + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + // Start a query and run some data to force snapshot uploads + val inputData = MemoryStream[Int] + val query = setUpStatefulQuery(inputData, "query") + + // Go through two batches to force two snapshot uploads. + // This would be enough to pass the version check for lagging stores. + inputData.addData(1, 2, 3) + query.processAllAvailable() + inputData.addData(1, 2, 3) + query.processAllAvailable() + + // Sleep for the duration of a maintenance interval - which should be enough + // to pass the time check for lagging stores. + Thread.sleep(100) + + val latestVersion = + query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastProgress.batchId + 1 + // Verify that no instances are marked as lagging, even when upload messages are sent. + // Since snapshot uploads are tied to commit, the lack of version difference should prevent + // the stores from being marked as lagging. 
+ assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) + query.stop() + } + } + + test("SPARK-51358: Snapshot lag reports properly detects when all state stores are lagging") { + withCoordinatorAndSQLConf( + sc, + // Only use two partitions with the faulty store provider (both stores will skip uploads) + SQLConf.SHUFFLE_PARTITIONS.key -> "2", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG.key -> "1", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" + ) { + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + // Start a query and run some data to force snapshot uploads + val inputData = MemoryStream[Int] + val query = setUpStatefulQuery(inputData, "query") + + // Go through several rounds of input to force snapshot uploads + (0 until 3).foreach { _ => + inputData.addData(1, 2, 3) + query.processAllAvailable() + Thread.sleep(500) + } + val latestVersion = + query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastProgress.batchId + 1 + // Verify that all instances are marked as lagging, since no upload messages are being sent + assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).size == 2) + query.stop() + } + } +} + +class StateStoreCoordinatorStreamingSuite extends StreamTest { + import testImplicits._ + + Seq( + ("RocksDB", classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName), + ("HDFS", classOf[HDFSBackedSkipMaintenanceOnCertainPartitionsProvider].getName) + ).foreach { case (providerName, providerClassName) => + test( + s"SPARK-51358: Restarting queries do not mark state stores as lagging for $providerName" + ) { + withSQLConf( + SQLConf.SHUFFLE_PARTITIONS.key -> "3", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "2", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG.key -> "5", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" + ) { + withTempDir { srcDir => + val inputData = MemoryStream[Int] + val query = inputData.toDF().dropDuplicates() + val numPartitions = query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS) + // Keep track of state checkpoint directory for the second run + var stateCheckpoint = "" + + testStream(query)( + StartStream(checkpointLocation = srcDir.getCanonicalPath), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + 
AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + Execute { query => + val coordRef = + query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator + stateCheckpoint = query.lastExecution.checkpointLocation + val latestVersion = query.lastProgress.batchId + 1 + + // Verify the coordinator logged snapshot uploads + (0 until numPartitions).map { + partitionId => + val storeId = StateStoreId(stateCheckpoint, 0, partitionId) + val providerId = StateStoreProviderId(storeId, query.runId) + if (partitionId <= 1) { + // Verify state stores in partition 0 and 1 are lagging and didn't upload + assert( + coordRef.getLatestSnapshotVersionForTesting(providerId).getOrElse(0) == 0 + ) + } else { + // Verify other stores have uploaded a snapshot and it's properly logged + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) + } + } + // Verify that the normal state store (partitionId=2) is not lagging behind, + // and the faulty stores are reported as lagging. + val laggingStores = + coordRef.getLaggingStoresForTesting(query.runId, latestVersion) + assert(laggingStores.size == 2) + assert(laggingStores.forall(_.storeId.partitionId <= 1)) + }, + // Stopping the streaming query should deactivate and clear snapshot uploaded events + StopStream, + Execute { query => + val coordRef = + query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator + val latestVersion = query.lastProgress.batchId + 1 + + // Verify we evicted the previous latest uploaded snapshots from the coordinator + (0 until numPartitions).map { partitionId => + val storeId = StateStoreId(stateCheckpoint, 0, partitionId) + val providerId = StateStoreProviderId(storeId, query.runId) + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).isEmpty) + } + // Verify that we are not reporting any lagging stores after eviction, + // since none of these state stores are active anymore. + assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) + } + ) + // Restart the query, but do not add too much data so that we don't associate + // the current StateStoreProviderId (store id + query run id) with any new uploads. + testStream(query)( + StartStream(checkpointLocation = srcDir.getCanonicalPath), + // Perform one round of data, which is enough to activate instances and force a + // lagging instance report, but not enough to trigger a snapshot upload yet. 
+ AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + Execute { query => + val coordRef = + query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator + val latestVersion = query.lastProgress.batchId + 1 + // Verify that the state stores have restored their snapshot version from the + // checkpoint and reported their current version + (0 until numPartitions).map { + partitionId => + val storeId = StateStoreId(stateCheckpoint, 0, partitionId) + val providerId = StateStoreProviderId(storeId, query.runId) + val latestSnapshotVersion = + coordRef.getLatestSnapshotVersionForTesting(providerId) + if (partitionId <= 1) { + // Verify state stores in partition 0/1 are still lagging and didn't upload + assert(latestSnapshotVersion.getOrElse(0) == 0) + } else { + // Verify other stores have uploaded a snapshot and it's properly logged + assert(latestSnapshotVersion.get > 0) + } + } + // Sleep a bit to allow the coordinator to pass the time threshold and report lag + Thread.sleep(5 * 100) + // Verify that we're reporting the faulty state stores (partitionId 0 and 1) + val laggingStores = + coordRef.getLaggingStoresForTesting(query.runId, latestVersion) + assert(laggingStores.size == 2) + assert(laggingStores.forall(_.storeId.partitionId <= 1)) + }, + StopStream + ) + } + } + } + } + + test("SPARK-51358: Restarting queries with updated SQLConf get propagated to the coordinator") { + withSQLConf( + SQLConf.SHUFFLE_PARTITIONS.key -> "3", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "1", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG.key -> "5", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" + ) { + withTempDir { srcDir => + val inputData = MemoryStream[Int] + val query = inputData.toDF().dropDuplicates() + + testStream(query)( + StartStream(checkpointLocation = srcDir.getCanonicalPath), + // Process multiple batches so that the coordinator can start reporting lagging instances + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + Execute { query => + val coordRef = + query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator + val latestVersion = query.lastProgress.batchId + 1 + // Sleep a bit to allow the coordinator to pass the time threshold and report lag + Thread.sleep(5 * 100) + // Verify that only the faulty stores are reported as lagging + val laggingStores = + coordRef.getLaggingStoresForTesting(query.runId, latestVersion) + assert(laggingStores.size == 2) + assert(laggingStores.forall(_.storeId.partitionId <= 1)) + }, + // Stopping the streaming query should deactivate and clear snapshot uploaded events + StopStream + ) + // Bump up version multiplier, which would stop the coordinator from reporting + // lagging stores for the next few versions + spark.conf + .set(SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key, "10") + // Restart the query, and verify the conf change reflects in the 
coordinator + testStream(query)( + StartStream(checkpointLocation = srcDir.getCanonicalPath), + // Process the same amount of data as the first run + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + Execute { query => + val coordRef = + query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator + val latestVersion = query.lastProgress.batchId + 1 + // Sleep the same amount to mimic conditions from first run + Thread.sleep(5 * 100) + // Verify that we are not reporting any lagging stores despite restarting + // because of the higher version multiplier + assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) + }, + StopStream + ) + } + } + } + + Seq( + ("RocksDB", classOf[RocksDBStateStoreProvider].getName), + ("HDFS", classOf[HDFSBackedStateStoreProvider].getName) + ).foreach { case (providerName, providerClassName) => + test( + s"SPARK-51358: Infrequent maintenance with $providerName using Trigger.AvailableNow " + + s"should be reported" + ) { + withSQLConf( + SQLConf.SHUFFLE_PARTITIONS.key -> "2", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG.key -> "50", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" + ) { + withTempDir { srcDir => + val inputData = MemoryStream[Int] + val query = inputData.toDF().dropDuplicates() + + // Populate state stores with an initial snapshot, so that timestamp isn't marked + // as the default 0ms. + testStream(query)( + StartStream(checkpointLocation = srcDir.getCanonicalPath), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable() + ) + // Increase maintenance interval to a much larger value to stop snapshot uploads + spark.conf.set(SQLConf.STREAMING_MAINTENANCE_INTERVAL.key, "60000") + // Execute a few batches in a short span + testStream(query)( + AddData(inputData, 1, 2, 3), + StartStream(Trigger.AvailableNow, checkpointLocation = srcDir.getCanonicalPath), + Execute { query => + query.awaitTermination() + // Verify the query ran with the AvailableNow trigger + assert(query.lastExecution.isTerminatingTrigger) + }, + AddData(inputData, 1, 2, 3), + StartStream(Trigger.AvailableNow, checkpointLocation = srcDir.getCanonicalPath), + Execute { query => + query.awaitTermination() + }, + // Start without available now, otherwise the stream closes too quickly for the + // testing RPC call to report lagging state stores + StartStream(checkpointLocation = srcDir.getCanonicalPath), + // Process data to activate state stores, but not enough to trigger snapshot uploads + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + Execute { query => + val coordRef = + query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator + val latestVersion = query.lastProgress.batchId + 1 + // Verify that all faulty stores are reported as lagging despite the short burst. 
+ // This test scenario mimics cases where snapshots have not been uploaded for + // a while due to the short running duration of AvailableNow. + val laggingStores = coordRef.getLaggingStoresForTesting( + query.runId, + latestVersion, + isTerminatingTrigger = true + ) + assert(laggingStores.size == 2) + assert(laggingStores.forall(_.storeId.partitionId <= 1)) + }, + StopStream + ) + } + } + } + } } object StateStoreCoordinatorSuite { def withCoordinatorRef(sc: SparkContext)(body: StateStoreCoordinatorRef => Unit): Unit = { var coordinatorRef: StateStoreCoordinatorRef = null try { - coordinatorRef = StateStoreCoordinatorRef.forDriver(sc.env) + coordinatorRef = StateStoreCoordinatorRef.forDriver(sc.env, new SQLConf) body(coordinatorRef) } finally { if (coordinatorRef != null) coordinatorRef.stop() } } + + def withCoordinatorAndSQLConf(sc: SparkContext, pairs: (String, String)*)( + body: (StateStoreCoordinatorRef, SparkSession) => Unit): Unit = { + var spark: SparkSession = null + var coordinatorRef: StateStoreCoordinatorRef = null + try { + spark = SparkSession.builder().sparkContext(sc).getOrCreate() + SparkSession.setActiveSession(spark) + coordinatorRef = spark.streams.stateStoreCoordinator + // Set up SQLConf entries + pairs.foreach { case (key, value) => spark.conf.set(key, value) } + body(coordinatorRef, spark) + } finally { + SparkSession.getActiveSession.foreach(_.streams.active.foreach(_.stop())) + // Unset all custom SQLConf entries + if (spark != null) pairs.foreach { case (key, _) => spark.conf.unset(key) } + if (coordinatorRef != null) coordinatorRef.stop() + StateStore.stop() + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala index 8af42d6dec26..093e8b991cc9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala @@ -427,6 +427,7 @@ abstract class StateVariableSuiteBase extends SharedSparkSession before { StateStore.stop() require(!StateStore.isMaintenanceRunning) + spark.streams.stateStoreCoordinator // initialize the lazy coordinator } after { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org
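For readers skimming the flattened diff above, the pattern the new tests exercise can be condensed into a short sketch. This is not part of the patch itself; it is a minimal illustration that reuses only names introduced or touched by this commit (the withCoordinatorAndSQLConf helper from the StateStoreCoordinatorSuite companion object, the test-only getLaggingStoresForTesting RPC, and the new STATE_STORE_COORDINATOR_* conf entries) and assumes the same imports and SharedSparkContext setup as StateStoreCoordinatorSuite. The expected outcome noted in the comments is hedged: with a healthy provider no store should be reported, while the skip-maintenance providers used in the suite are what produce lagging entries.

    withCoordinatorAndSQLConf(
      sc,
      SQLConf.SHUFFLE_PARTITIONS.key -> "5",
      SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
      SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true",
      SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2",
      SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0"
    ) { case (coordRef, spark) =>
      import spark.implicits._
      implicit val sqlContext = spark.sqlContext
      // Stateful dropDuplicates query, mirroring setUpStatefulQuery in the suite
      val inputData = MemoryStream[Int]
      val query = inputData.toDF().dropDuplicates()
        .writeStream
        .format("memory")
        .outputMode("update")
        .queryName("lag_detection_sketch")
        .option("checkpointLocation", Utils.createTempDir().getAbsolutePath)
        .start()
      // Run a few batches so maintenance gets a chance to upload snapshots
      (0 until 6).foreach { _ =>
        inputData.addData(1, 2, 3)
        query.processAllAvailable()
        Thread.sleep(500)
      }
      // Ask the coordinator (test-only RPC) which stores it currently considers lagging
      val latestVersion = query.lastProgress.batchId + 1
      val lagging = coordRef.getLaggingStoresForTesting(query.runId, latestVersion)
      // With a healthy provider this is expected to be empty; with a provider that skips
      // maintenance on some partitions, those partitions should appear here, as the
      // tests in this patch verify. Each entry carries the store id and query run id.
      lagging.foreach { s => assert(s.queryRunId == query.runId) }
      query.stop()
    }

The sketch follows the same flow as the new tests: enable the reporting conf, force snapshot uploads by committing several batches, then query the coordinator directly rather than relying on streaming query progress reports.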