This is an automated email from the ASF dual-hosted git repository.

anishshri-db pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c19b79faa5fe [SPARK-56970][SS] Split CommitMetadata into CommitMetadataBase + V1/V2 case classes
c19b79faa5fe is described below

commit c19b79faa5fec3d21a7fdc83e9919782a567e23f
Author: ericm-db <[email protected]>
AuthorDate: Wed Jun 3 09:18:03 2026 -0700

    [SPARK-56970][SS] Split CommitMetadata into CommitMetadataBase + V1/V2 case classes
    
    ### What changes were proposed in this pull request?
    
    Refactor `CommitLog` so that the commit log metadata is dispatched through a `CommitMetadataBase` trait with concrete `CommitMetadata` (V1, watermark only) and `CommitMetadataV2` (watermark + `stateUniqueIds`) case classes. The deserializer now reads the wire-format version from the file header and constructs the matching subclass.
    
    This prepares for `CommitMetadataV3`, which a follow-up PR adds to carry sink metadata for streaming sink evolution.
    
    Notable changes:
    - Add `CommitMetadataBase` trait and `CommitMetadataV2` case class.
    - `CommitMetadata` becomes V1 (no `stateUniqueIds` field).
    - Add a `CommitLog.createMetadata` factory that dispatches by version and defaults to the configured `STATE_STORE_CHECKPOINT_FORMAT_VERSION`.
    - `CommitLog.readCommitMetadata` reads the version line and constructs the matching subclass.
    - `MicroBatchExecution`, `OfflineStateRepartitionRunner`, and existing tests are updated to use the new types and factory.
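The factory and the V1 invariant can be modeled as follows. This is a hedged sketch, not Spark's code: `FactorySketch`, `Meta`, `MetaV1`, `MetaV2`, and `Log` are hypothetical stand-ins for `CommitMetadataBase`, `CommitMetadata`, `CommitMetadataV2`, and `CommitLog`, and the `configuredVersion` constructor argument plays the role of `STATE_STORE_CHECKPOINT_FORMAT_VERSION`.

```scala
object FactorySketch {
  type Ids = Option[Map[Long, Array[Array[String]]]]

  sealed trait Meta {
    def version: Int
    def nextBatchWatermarkMs: Long
    def stateUniqueIds: Ids
    // Copy with new ids while keeping the concrete class (and so the version).
    def withStateUniqueIds(ids: Ids): Meta
  }

  final case class MetaV1(nextBatchWatermarkMs: Long = 0) extends Meta {
    val version: Int = 1
    def stateUniqueIds: Ids = None
    // V1 has no field for ids, so only None or an empty map is accepted.
    def withStateUniqueIds(ids: Ids): MetaV1 = {
      require(ids.forall(_.isEmpty), "stateUniqueIds cannot be set for version 1")
      this
    }
  }

  final case class MetaV2(nextBatchWatermarkMs: Long = 0, stateUniqueIds: Ids = None)
      extends Meta {
    val version: Int = 2
    def withStateUniqueIds(ids: Ids): MetaV2 = copy(stateUniqueIds = ids)
  }

  // `configuredVersion` stands in for STATE_STORE_CHECKPOINT_FORMAT_VERSION.
  final class Log(configuredVersion: Int) {
    def createMetadata(
        nextBatchWatermarkMs: Long = 0,
        stateUniqueIds: Ids = None,
        commitLogFormatVersion: Int = configuredVersion): Meta =
      commitLogFormatVersion match {
        case 2 => MetaV2(nextBatchWatermarkMs, stateUniqueIds)
        case 1 => MetaV1(nextBatchWatermarkMs).withStateUniqueIds(stateUniqueIds)
        case v => throw new IllegalArgumentException(s"Unsupported version $v")
      }
  }
}
```

Deriving from an existing commit through `withStateUniqueIds` keeps the source checkpoint's version. That is why `OfflineStateRepartitionRunner` switches from `copy` to that method.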
    
    This PR is the first follow-up in the SPARK-56719 sink-evolution series. The next two follow-ups are stacked on top of this branch (SPARK-56971: add `CommitMetadataV3` + `SinkMetadataInfo`; SPARK-56972: wire sink name persistence through `MicroBatchExecution`).
    
    ### Why are the changes needed?
    
    The pre-refactor `CommitMetadata` carried both the V1 and V2 wire shapes in a single case class, with `stateUniqueIds` optional. That made it awkward to add a V3 wire format with additional fields, and forced `serialize` to take the wire version from `SQLConf` rather than from the metadata itself.
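To show where the header now comes from, here is a hedged sketch of a `serialize` that reads the version off the metadata object. `SerializeSketch`, `Meta`, `MetaV1`, `MetaV2`, and the hand-built `payload` JSON are hypothetical; the real `CommitLog` writes `metadata.json`, which json4s produces.

```scala
import java.io.OutputStream
import java.nio.charset.StandardCharsets.UTF_8

object SerializeSketch {
  sealed trait Meta {
    def version: Int
    def payload: String
  }
  final case class MetaV1(nextBatchWatermarkMs: Long) extends Meta {
    val version: Int = 1
    def payload: String = s"""{"nextBatchWatermarkMs":$nextBatchWatermarkMs}"""
  }
  final case class MetaV2(nextBatchWatermarkMs: Long) extends Meta {
    val version: Int = 2
    // State unique ids are left out of this toy payload.
    def payload: String = s"""{"nextBatchWatermarkMs":$nextBatchWatermarkMs}"""
  }

  // The header is taken from the metadata itself, so one log instance can
  // write any supported version without consulting a conf.
  def serialize(metadata: Meta, out: OutputStream): Unit = {
    out.write(s"v${metadata.version}".getBytes(UTF_8))
    out.write('\n'.toInt)
    out.write(metadata.payload.getBytes(UTF_8))
  }
}
```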
    
    ### Does this PR introduce _any_ user-facing change?
    
    No new public API. The V1 wire format changes slightly: V1 commit log files no longer serialize `stateUniqueIds: null`. Existing V1 files can still be read because the V1 deserializer ignores the (now-unknown) field.
    
    This PR also relaxes the exact-version-match check on read so that a commit log opened with the V2 conf can deserialize a V1 file. This incidentally resolves SPARK-50653.
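The difference between the old check and the relaxed one can be sketched as two hypothetical validators. `VersionCheckSketch`, `exactMatch`, and `atMost` are illustrative names. The PR itself calls `MetadataVersionUtil.validateVersion(..., MAX_VERSION)`, and this sketch does not reproduce that method's error classes or messages.

```scala
object VersionCheckSketch {
  // Parses a header such as "v2" into 2; rejects anything else.
  private def parse(header: String): Int = {
    val digits = header.stripPrefix("v")
    if (!header.startsWith("v") || digits.isEmpty || !digits.forall(_.isDigit)) {
      throw new IllegalStateException(s"Malformed log version header: '$header'")
    }
    digits.toInt
  }

  // Pre-PR behavior: the file's version had to equal the configured version.
  def exactMatch(header: String, configured: Int): Int = {
    val v = parse(header)
    if (v != configured) {
      throw new IllegalStateException(s"Log version v$v does not match configured v$configured")
    }
    v
  }

  // Relaxed behavior: any version from 1 up to the newest this reader knows is accepted.
  def atMost(header: String, maxSupported: Int): Int = {
    val v = parse(header)
    if (v < 1 || v > maxSupported) {
      throw new IllegalStateException(s"Log version v$v is not supported (max v$maxSupported)")
    }
    v
  }
}
```

Under the relaxed rule, a V2-configured reader accepts a `v1` header, while a header newer than the reader's maximum is still rejected.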
    
    ### How was this patch tested?
    
    - Existing `CommitLogSuite` (V1, V2, and cross-version) passes; the cross-version test now asserts successful V1 deserialization.
    - `StreamingSinkEvolutionSuite` (from SPARK-56719) still passes.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Generated-by: Claude Code (claude-opus-4-7)
    
    Closes #56018 from ericm-db/sink-evolution-commit-log-v3.
    
    Lead-authored-by: ericm-db <[email protected]>
    Co-authored-by: Eric Marnadi <[email protected]>
    Signed-off-by: Anish Shrigondekar <[email protected]>
---
 .../streaming/checkpointing/AsyncCommitLog.scala   |   4 +-
 .../streaming/checkpointing/CommitLog.scala        | 131 +++++++++++++++++----
 .../streaming/runtime/MicroBatchExecution.scala    |   6 +-
 .../state/OfflineStateRepartitionRunner.scala      |   4 +-
 .../execution/streaming/state/StateRewriter.scala  |  34 ++----
 .../state/StateDataSourceChangeDataReadSuite.scala |   6 +-
 .../v2/state/StateDataSourceReadSuite.scala        |  77 +++++++-----
 ...tatePartitionAllColumnFamiliesWriterSuite.scala |   2 +-
 .../spark/sql/streaming/CommitLogSuite.scala       |  53 ++++++---
 9 files changed, 222 insertions(+), 95 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/AsyncCommitLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/AsyncCommitLog.scala
index 7a6c26b249e9..d13affb82dbb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/AsyncCommitLog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/AsyncCommitLog.scala
@@ -53,7 +53,7 @@ class AsyncCommitLog(
    *         the async write of the batch is completed.  Future may also be completed exceptionally
    *         to indicate some write error.
    */
-  def addAsync(batchId: Long, metadata: CommitMetadata): CompletableFuture[Long] = {
+  def addAsync(batchId: Long, metadata: CommitMetadataBase): CompletableFuture[Long] = {
     require(metadata != null, "'null' metadata cannot be written to a metadata log")
     val future: CompletableFuture[Long] = addNewBatchByStreamAsync(batchId) { output =>
       serialize(metadata, output)
@@ -77,7 +77,7 @@ class AsyncCommitLog(
    * @param metadata metadata of batch to write
    * @return true if operation is successful otherwise false.
    */
-  def addInMemory(batchId: Long, metadata: CommitMetadata): Boolean = {
+  def addInMemory(batchId: Long, metadata: CommitMetadataBase): Boolean = {
     if (batchCache.containsKey(batchId)) {
       false
     } else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/CommitLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/CommitLog.scala
index b73020b6060c..820aecf70d0e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/CommitLog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/CommitLog.scala
@@ -26,6 +26,7 @@ import org.json4s.{Formats, NoTypeHints}
 import org.json4s.jackson.Serialization
 
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.internal.SQLConf
 
 /**
@@ -50,39 +51,119 @@ class CommitLog(
     sparkSession: SparkSession,
     path: String,
     readOnly: Boolean = false)
-  extends HDFSMetadataLog[CommitMetadata](sparkSession, path, readOnly) {
+  extends HDFSMetadataLog[CommitMetadataBase](sparkSession, path, readOnly) {
 
   import CommitLog._
 
-  private val VERSION: Int = sparkSession.conf.get(
+  // The configured commit log format version. Used as the default version when callers
+  // construct metadata through [[createMetadata]].
+  private[sql] val defaultVersion: Int = sparkSession.conf.get(
     SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key).toInt
 
-  override protected[sql] def deserialize(in: InputStream): CommitMetadata = {
-    // called inside a try-finally where the underlying stream is closed in the caller
-    val lines = IOSource.fromInputStream(in, UTF_8.name()).getLines()
-    if (!lines.hasNext) {
-      throw new IllegalStateException("Incomplete log file in the offset commit log")
-    }
-    // TODO [SPARK-49462] This validation should be relaxed for a stateless query.
-    // TODO [SPARK-50653] This validation should be relaxed to support reading
-    //  a V1 log file when VERSION is V2
-    validateVersionExactMatch(lines.next().trim, VERSION)
-    val metadataJson = if (lines.hasNext) lines.next() else EMPTY_JSON
-    CommitMetadata(metadataJson)
+  override protected[sql] def deserialize(in: InputStream): CommitMetadataBase = {
+    CommitLog.readCommitMetadata(in)
   }
 
-  override protected[sql] def serialize(metadata: CommitMetadata, out: OutputStream): Unit = {
+  override protected[sql] def serialize(metadata: CommitMetadataBase, out: OutputStream): Unit = {
     // called inside a try-finally where the underlying stream is closed in the caller
-    out.write(s"v${VERSION}".getBytes(UTF_8))
+    out.write(s"v${metadata.version}".getBytes(UTF_8))
     out.write('\n')
 
     // write metadata
     out.write(metadata.json.getBytes(UTF_8))
   }
+
+  /**
+   * Factory for creating a [[CommitMetadataBase]] for the requested wire format version.
+   * Defaults to the version configured via [[SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION]].
+   */
+  def createMetadata(
+      nextBatchWatermarkMs: Long = 0,
+      stateUniqueIds: Option[Map[Long, Array[Array[String]]]] = None,
+      commitLogFormatVersion: Int = defaultVersion): CommitMetadataBase = {
+    commitLogFormatVersion match {
+      case VERSION_2 =>
+        CommitMetadataV2(nextBatchWatermarkMs, stateUniqueIds)
+      case VERSION_1 =>
+        // VERSION_1 cannot persist stateUniqueIds; withStateUniqueIds enforces this invariant
+        // (it throws if stateUniqueIds is non-empty).
+        CommitMetadata(nextBatchWatermarkMs).withStateUniqueIds(stateUniqueIds)
+      case v =>
+        throw QueryExecutionErrors.logVersionGreaterThanSupported(v, CommitLog.MAX_VERSION)
+    }
+  }
 }
 
 object CommitLog {
   private val EMPTY_JSON = "{}"
+  val VERSION_1 = 1
+  val VERSION_2 = 2
+  val MAX_VERSION: Int = VERSION_2
+
+  /**
+   * Reads a single commit log entry and dispatches to the matching
+   * [[CommitMetadataBase]] subclass based on the wire format version recorded in the file.
+   */
+  private[spark] def readCommitMetadata(in: InputStream): CommitMetadataBase = {
+    val lines = IOSource.fromInputStream(in, UTF_8.name()).getLines()
+    if (!lines.hasNext) {
+      throw new IllegalStateException("Incomplete log file in the offset commit log")
+    }
+    val version = MetadataVersionUtil.validateVersion(lines.next().trim, MAX_VERSION)
+    val metadataJson = if (lines.hasNext) lines.next() else EMPTY_JSON
+    version match {
+      case VERSION_2 => CommitMetadataV2(metadataJson)
+      case VERSION_1 => CommitMetadata(metadataJson)
+      case v => throw QueryExecutionErrors.logVersionGreaterThanSupported(v, MAX_VERSION)
+    }
+  }
+}
+
+/**
+ * Base trait for commit log metadata. Concrete subclasses correspond to wire format versions
+ * and override [[version]] accordingly.
+ */
+trait CommitMetadataBase extends Serializable {
+  def version: Int
+  def nextBatchWatermarkMs: Long
+  def stateUniqueIds: Option[Map[Long, Array[Array[String]]]]
+
+  /**
+   * Returns a copy of this metadata with the given state store unique ids, preserving the
+   * concrete subclass and all of its other fields. Deriving a new commit from an existing one
+   * should go through this method (rather than reconstructing via [[CommitLog.createMetadata]])
+   * so that version-specific fields are not silently dropped when new metadata versions are
+   * introduced.
+   */
+  def withStateUniqueIds(
+      stateUniqueIds: Option[Map[Long, Array[Array[String]]]]): CommitMetadataBase
+
+  def json: String = Serialization.write(this)(CommitMetadata.format)
+}
+
+/**
+ * Commit log metadata for [[CommitLog.VERSION_1]]. Records the watermark for the next batch only.
+ *
+ * @param nextBatchWatermarkMs The watermark of the next batch.
+ */
+case class CommitMetadata(
+    nextBatchWatermarkMs: Long = 0) extends CommitMetadataBase {
+  override def version: Int = CommitLog.VERSION_1
+  override def stateUniqueIds: Option[Map[Long, Array[Array[String]]]] = None
+
+  override def withStateUniqueIds(
+      stateUniqueIds: Option[Map[Long, Array[Array[String]]]]): CommitMetadata = {
+    require(stateUniqueIds.forall(_.isEmpty),
+      s"stateUniqueIds cannot be set for commit log format version ${CommitLog.VERSION_1}; " +
+        s"use version ${CommitLog.VERSION_2} to persist state store checkpoint ids.")
+    this
+  }
+}
+
+object CommitMetadata {
+  implicit val format: Formats = Serialization.formats(NoTypeHints)
+
+  def apply(json: String): CommitMetadata = Serialization.read[CommitMetadata](json)
 }
 
 /**
@@ -104,19 +185,23 @@ object CommitLog {
  *          +--- ......
  * In the commit log, in addition to nextBatchWatermarkMs, we also store the unique ids of the
  * state store files.
+ *
  * @param nextBatchWatermarkMs The watermark of the next batch.
  * @param stateUniqueIds Map[Long, Array[Array[String]]] of map
  *                       OperatorId -> (partitionID -> array of uniqueID)
  */
-
-case class CommitMetadata(
+case class CommitMetadataV2(
     nextBatchWatermarkMs: Long = 0,
-    stateUniqueIds: Option[Map[Long, Array[Array[String]]]] = None) {
-  def json: String = Serialization.write(this)(CommitMetadata.format)
+    stateUniqueIds: Option[Map[Long, Array[Array[String]]]] = None) extends CommitMetadataBase {
+  override def version: Int = CommitLog.VERSION_2
+
+  override def withStateUniqueIds(
+      stateUniqueIds: Option[Map[Long, Array[Array[String]]]]): CommitMetadataV2 =
+    copy(stateUniqueIds = stateUniqueIds)
 }
 
-object CommitMetadata {
-  implicit val format: Formats = Serialization.formats(NoTypeHints)
+object CommitMetadataV2 {
+  import CommitMetadata.format
 
-  def apply(json: String): CommitMetadata = Serialization.read[CommitMetadata](json)
+  def apply(json: String): CommitMetadataV2 = Serialization.read[CommitMetadataV2](json)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala
index 68914913a00e..8fc6c718c5a0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala
@@ -46,7 +46,7 @@ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, RealTimeStreamScanExec, StreamingDataSourceV2Relation, StreamingDataSourceV2ScanRelation, StreamWriterCommitProgress, WriteToDataSourceV2Exec}
 import org.apache.spark.sql.execution.streaming.{AvailableNowTrigger, Offset, OneTimeTrigger, ProcessingTimeTrigger, RealTimeModeAllowlist, RealTimeTrigger, Sink, Source, StreamingQueryPlanTraverseHelper}
-import org.apache.spark.sql.execution.streaming.checkpointing.{CheckpointFileManager, CheckpointVersionManager, CommitMetadata, OffsetLogType, OffsetSeqBase, OffsetSeqLog, OffsetSeqMetadata, OffsetSeqMetadataV2}
+import org.apache.spark.sql.execution.streaming.checkpointing.{CheckpointFileManager, CheckpointVersionManager, OffsetLogType, OffsetSeqBase, OffsetSeqLog, OffsetSeqMetadata, OffsetSeqMetadataV2}
 import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo, StatefulOpStateStoreCheckpointInfo, StateStoreWriter}
 import org.apache.spark.sql.execution.streaming.runtime.StreamingCheckpointConstants.{DIR_NAME_COMMITS, DIR_NAME_OFFSETS, DIR_NAME_STATE}
 import org.apache.spark.sql.execution.streaming.sources.{ForeachBatchSink, WriteToMicroBatchDataSource, WriteToMicroBatchDataSourceV1}
@@ -1465,7 +1465,9 @@ class MicroBatchExecution(
         None
       }
       if (!commitLog.add(execCtx.batchId,
-        CommitMetadata(watermarkTracker.currentWatermark, stateStoreCkptId))) {
+        commitLog.createMetadata(
+          nextBatchWatermarkMs = watermarkTracker.currentWatermark,
+          stateUniqueIds = stateStoreCkptId))) {
         throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId)
       }
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala
index 1491d2698906..dc13fa1030a0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala
@@ -294,7 +294,9 @@ class OfflineStateRepartitionRunner(
       lastCommittedBatchId: Long,
       opIdToStateStoreCkptInfo: Option[Map[Long, Array[Array[String]]]]): Unit = {
     val latestCommit = checkpointMetadata.commitLog.get(lastCommittedBatchId).get
-    val commitMetadata = latestCommit.copy(stateUniqueIds = opIdToStateStoreCkptInfo)
+    // Derive the new commit from the latest one so version-specific fields are preserved and the
+    // wire format version stays consistent with the source checkpoint.
+    val commitMetadata = latestCommit.withStateUniqueIds(opIdToStateStoreCkptInfo)
 
     if (!checkpointMetadata.commitLog.add(newBatchId, commitMetadata)) {
       throw QueryExecutionErrors.concurrentStreamLogUpdate(newBatchId)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala
index fd890161caaf..546a9a601964 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala
@@ -22,7 +22,7 @@ import java.util.UUID
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.{SparkIllegalStateException, SparkThrowable, TaskContext}
+import org.apache.spark.{SparkIllegalStateException, TaskContext}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.LogKeys._
@@ -376,27 +376,19 @@ class StateRewriter(
   }
 
   private def verifyCheckpointFormatVersion(): Unit = {
-    // Verify checkpoint version in sqlConf based on commitLog for readCheckpoint
-    // in case user forgot to set STATE_STORE_CHECKPOINT_FORMAT_VERSION.
-    // Using read batch commit since the latest commit could be a skipped batch.
-    // If SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION is wrong, readCheckpoint.commitLog
-    // will throw an exception, and we will propagate this exception upstream.
-    // This prevents the StateRewriter from failing to write the correct state files
-    try {
-      readCheckpoint.commitLog.get(readBatchId)
-    } catch {
-        case e: IllegalStateException if e.getCause != null &&
-            e.getCause.isInstanceOf[SparkThrowable] =>
-          val sparkThrowable = e.getCause.asInstanceOf[SparkThrowable]
-          if (sparkThrowable.getCondition == "INVALID_LOG_VERSION.EXACT_MATCH_VERSION") {
-            val params = sparkThrowable.getMessageParameters
-            val expectedVersion = params.get("version")
-            val actualVersion = params.get("matchVersion")
-            throw StateRewriterErrors.stateCheckpointFormatVersionMismatchError(
-              checkpointLocationForRead, expectedVersion, actualVersion)
-          }
-          throw e
+    // Verify checkpoint version in sqlConf matches the version recorded in the read commit log,
+    // in case the user forgot to set STATE_STORE_CHECKPOINT_FORMAT_VERSION. This prevents the
+    // StateRewriter from writing state files in a format that disagrees with the source
+    // checkpoint. Using the read batch commit since the latest commit could be a skipped batch.
+    readCheckpoint.commitLog.get(readBatchId).foreach { metadata =>
+      val configuredVersion = readCheckpoint.commitLog.defaultVersion
+      if (metadata.version != configuredVersion) {
+        throw StateRewriterErrors.stateCheckpointFormatVersionMismatchError(
+          checkpointLocationForRead,
+          expectedVersion = metadata.version.toString,
+          actualVersion = configuredVersion.toString)
       }
+    }
   }
 }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala
index bae78f0b4762..4e9f6cca2ffc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala
@@ -25,7 +25,7 @@ import org.apache.hadoop.conf.Configuration
 import org.scalatest.Assertions
 
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.execution.streaming.checkpointing.{CommitLog, CommitMetadata}
+import org.apache.spark.sql.execution.streaming.checkpointing.{CommitLog, CommitMetadata, CommitMetadataV2}
 import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, StreamExecution}
 import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.functions.{col, window}
@@ -237,11 +237,11 @@ abstract class StateDataSourceChangeDataReaderSuite extends StateDataSourceTestB
         new File(tempDir.getAbsolutePath, "commits").getAbsolutePath)
 
       // Start version: treated as v1 (no operator unique ids)
-      val startMetadata = CommitMetadata(0, None)
+      val startMetadata = CommitMetadata(0)
       assert(commitLog.add(0, startMetadata))
 
       // End version: treated as v2 (operator 0 has unique ids)
-      val endMetadata = CommitMetadata(0,
+      val endMetadata = CommitMetadataV2(0,
         Some(Map[Long, Array[Array[String]]](0L -> Array(Array("uid")))))
       assert(commitLog.add(1, endMetadata))
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
index 2def79828fac..4a2a454077a7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
@@ -23,7 +23,7 @@ import java.util.UUID
 import org.apache.hadoop.conf.Configuration
 import org.scalatest.Assertions
 
-import org.apache.spark.{SparkException, SparkThrowable, SparkUnsupportedOperationException}
+import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.sql.{AnalysisException, DataFrame, Encoders, Row}
 import org.apache.spark.sql.catalyst.expressions.{BoundReference, GenericInternalRow}
@@ -589,8 +589,6 @@ class RocksDBWithCheckpointV2StateDataSourceReaderSuite extends StateDataSourceR
   override protected def newStateStoreProvider(): RocksDBStateStoreProvider =
     new RocksDBStateStoreProvider
 
-  import testImplicits._
-
   override def beforeAll(): Unit = {
     super.beforeAll()
     spark.conf.set(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION, 2)
@@ -600,34 +598,57 @@ class RocksDBWithCheckpointV2StateDataSourceReaderSuite extends StateDataSourceR
       "true")
   }
 
-  // TODO: Remove this test once we allow migrations from checkpoint v1 to v2
-  test("reading checkpoint v2 store with version 1 should fail") {
-    withTempDir { tmpDir =>
-      val inputData = MemoryStream[(Int, Long)]
-      val query = getStreamStreamJoinQuery(inputData)
-      testStream(query)(
-        StartStream(checkpointLocation = tmpDir.getCanonicalPath),
-        AddData(inputData, (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L)),
-        ProcessAllAvailable(),
-        Execute { _ => Thread.sleep(2000) },
-        StopStream
-      )
+  // Expected state after runLargeDataStreamingAggregationQuery, read from batch 2 / operator 0.
+  private val expectedLargeAggregationState: Seq[Row] = Seq(
+    Row(0, 5, 60, 30, 0), Row(1, 5, 65, 31, 1), Row(2, 5, 70, 32, 2),
+    Row(3, 4, 72, 33, 3), Row(4, 4, 76, 34, 4), Row(5, 4, 80, 35, 5),
+    Row(6, 4, 84, 36, 6), Row(7, 4, 88, 37, 7), Row(8, 4, 92, 38, 8),
+    Row(9, 4, 96, 39, 9))
+
+  private def readLargeAggregationState(checkpointDir: String): DataFrame =
+    spark.read.format("statestore")
+      .option(StateSourceOptions.PATH, checkpointDir)
+      .option(StateSourceOptions.BATCH_ID, 2)
+      .option(StateSourceOptions.OPERATOR_ID, 0)
+      .load()
+      .selectExpr("key.groupKey AS key_groupKey", "value.count AS value_cnt",
+        "value.sum AS value_sum", "value.max AS value_max", "value.min AS value_min")
 
+  // SPARK-56970: The commit log wire format version is now discovered from the file header
+  // rather than required to match STATE_STORE_CHECKPOINT_FORMAT_VERSION. As a result a V1 commit
+  // log can be read under a V2-configured session (and vice versa). Note this only applies to the
+  // commit log layer; reading a V2 state store still requires version 2 to be configured because
+  // the state store files are named with checkpoint unique ids.
+  test("SPARK-56970: reading a v1 checkpoint with commit log version 2 configured succeeds") {
+    withTempDir { tempDir =>
+      // Override the suite default to write a V1 checkpoint (no checkpoint unique ids).
       withSQLConf(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "1") {
-        // Verify reading state throws error when reading checkpoint v2 with version 1
-        val exc = intercept[IllegalStateException] {
-          val stateDf = spark.read.format("statestore")
-            .option(StateSourceOptions.BATCH_ID, 0)
-            .option(StateSourceOptions.OPERATOR_ID, 0)
-            .load(tmpDir.getCanonicalPath)
-          stateDf.collect()
-        }
+        runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath)
+      }
+
+      // The suite default reads with version 2 configured; the V1 commit log must still be read.
+      checkAnswer(
+        readLargeAggregationState(tempDir.getAbsolutePath), expectedLargeAggregationState)
+    }
+  }
 
-        checkError(exc.getCause.asInstanceOf[SparkThrowable],
-          "INVALID_LOG_VERSION.EXACT_MATCH_VERSION", "KD002",
-          Map(
-            "version" -> "2",
-            "matchVersion" -> "1"))
+  test("SPARK-56970: reading a v2 checkpoint with commit log version 1 configured fails on the " +
+    "state store, not the commit log") {
+    withTempDir { tempDir =>
+      // The suite configures commit log format version 2, so this writes a V2 checkpoint whose
+      // state store files are named with checkpoint unique ids.
+      runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath)
+
+      withSQLConf(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "1") {
+        // The commit log now deserializes across versions, so this no longer fails with
+        // INVALID_LOG_VERSION at the commit-log layer. Reading the V2 state store itself still
+        // requires version 2 to be configured: with version 1 the reader looks for non-unique
+        // state file names and cannot locate the unique-id-named files.
+        val ex = intercept[SparkException] {
+          readLargeAggregationState(tempDir.getAbsolutePath).collect()
+        }
+        assert(ex.getMessage.contains("CANNOT_LOAD_STATE_STORE") ||
+          Option(ex.getCause).map(_.getMessage).exists(_.contains("CANNOT_LOAD_STATE_STORE")))
       }
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala
index be7874e806cd..22d0af0a77fd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala
@@ -99,7 +99,7 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
 
     // Commit to commitLog with checkpoint IDs
     val latestCommit = targetCheckpointMetadata.commitLog.get(lastBatch).get
-    val commitMetadata = latestCommit.copy(stateUniqueIds = checkpointInfos)
+    val commitMetadata = latestCommit.withStateUniqueIds(checkpointInfos)
     targetCheckpointMetadata.commitLog.add(writeBatchId, commitMetadata)
     val versionToCheck = writeBatchId + 1
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/CommitLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/CommitLogSuite.scala
index 332de78e7cbf..e4becd057168 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/CommitLogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/CommitLogSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.streaming
 import java.io.{ByteArrayInputStream, FileInputStream, FileOutputStream}
 import java.nio.file.Path
 
-import org.apache.spark.sql.execution.streaming.checkpointing.{CommitLog, CommitMetadata}
+import org.apache.spark.sql.execution.streaming.checkpointing.{CommitLog, CommitMetadata, CommitMetadataBase, CommitMetadataV2}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 
@@ -62,7 +62,7 @@ class CommitLogSuite extends SharedSparkSession {
     )
   }
 
-  private def testSerde(commitMetadata: CommitMetadata, path: Path): Unit = {
+  private def testSerde(commitMetadata: CommitMetadataBase, path: Path): Unit = {
     if (regenerateGoldenFiles) {
       val commitLog = new CommitLog(spark, path.toString)
       val outputStream = new FileOutputStream(path.resolve("testCommitLog").toFile)
@@ -102,19 +102,21 @@ class CommitLogSuite extends SharedSparkSession {
          0L -> Array(Array("unique_id1", "unique_id2"), Array("unique_id3", "unique_id4")),
            1L -> Array(Array("unique_id5", "unique_id6"), Array("unique_id7", "unique_id8"))
         )
-      val testMetadataV2 = CommitMetadata(0, Some(testStateUniqueIds))
+      val testMetadataV2 = CommitMetadataV2(0, Some(testStateUniqueIds))
       testSerde(testMetadataV2, testCommitLogV2FilePath)
     }
   }
 
   test("Basic Commit Log V2 SerDe - empty stateUniqueIds") {
     withSQLConf(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "2") {
-      val testMetadataV2 = CommitMetadata(0, Some(Map[Long, Array[Array[String]]]()))
+      val testMetadataV2 = CommitMetadataV2(0, Some(Map[Long, Array[Array[String]]]()))
       testSerde(testMetadataV2, testCommitLogV2FilePathEmptyUniqueId)
     }
   }
 
-  // Old metadata structure with no state unique ids should not affect the deserialization
+  // SPARK-50653: When the configured commit log version is V2, a V1 file on disk should still
+  // deserialize successfully into a V1 [[CommitMetadata]] because the wire format version is now
+  // discovered from the file header rather than enforced to match the conf.
   test("Cross-version V1 SerDe") {
     withSQLConf(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "2") {
       val commitlogV1 = """v1
@@ -122,18 +124,41 @@ class CommitLogSuite extends SharedSparkSession {
       val inputStream: ByteArrayInputStream =
         new ByteArrayInputStream(commitlogV1.getBytes("UTF-8"))
 
-      // TODO [SPARK-50653]: Uncomment the below when v2 -> v1 backward compatibility is added
-      // val commitMetadata: CommitMetadata = new CommitLog(
-      // spark, testCommitLogV1FilePath.toString).deserialize(inputStream)
-      // assert(commitMetadata.nextBatchWatermarkMs === 233)
-      // assert(commitMetadata.stateUniqueIds === Map.empty)
+      val commitMetadata = new CommitLog(
+        spark, testCommitLogV1FilePath.toString).deserialize(inputStream)
+      assert(commitMetadata.version === CommitLog.VERSION_1)
+      assert(commitMetadata.nextBatchWatermarkMs === 233)
+      assert(commitMetadata.stateUniqueIds.isEmpty)
+    }
+  }
+
+  test("SPARK-56970: creating a V1 commit with stateUniqueIds should fail") {
+    withTempDir { tmpDir =>
+      val commitLog = new CommitLog(spark, tmpDir.getCanonicalPath)
+      val stateUniqueIds: Map[Long, Array[Array[String]]] =
+        Map(0L -> Array(Array("unique_id1", "unique_id2")))
+
+      // Through the createMetadata factory with an explicit V1 format version.
+      val e1 = intercept[IllegalArgumentException] {
+        commitLog.createMetadata(
+          nextBatchWatermarkMs = 1,
+          stateUniqueIds = Some(stateUniqueIds),
+          commitLogFormatVersion = CommitLog.VERSION_1)
+      }
+      assert(e1.getMessage.contains("stateUniqueIds cannot be set"))
 
-      // TODO [SPARK-50653]: remove the below when v2 -> v1 backward compatibility is added
-      val e = intercept[IllegalStateException] {
-        new CommitLog(spark, testCommitLogV1FilePath.toString).deserialize(inputStream)
+      // Directly through withStateUniqueIds on a V1 metadata.
+      val e2 = intercept[IllegalArgumentException] {
+        CommitMetadata(1).withStateUniqueIds(Some(stateUniqueIds))
       }
+      assert(e2.getMessage.contains("stateUniqueIds cannot be set"))
 
-      assert (e.getMessage.contains("only supported log version"))
+      // None and an empty map are allowed for V1 (no unique ids to persist).
+      assert(CommitMetadata(1).withStateUniqueIds(None).stateUniqueIds.isEmpty)
+      assert(commitLog.createMetadata(
+        nextBatchWatermarkMs = 1,
+        stateUniqueIds = Some(Map.empty[Long, Array[Array[String]]]),
+        commitLogFormatVersion = CommitLog.VERSION_1).version === CommitLog.VERSION_1)
     }
   }
 }

