This is an automated email from the ASF dual-hosted git repository.

ashrigondekar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2871905e9ba8 [SPARK-54924][SS] State Rewriter to read state, transform it and write new state
2871905e9ba8 is described below

commit 2871905e9ba824f0c1c3ae2397334e838fa92faf
Author: micheal-o <[email protected]>
AuthorDate: Wed Jan 7 12:52:24 2026 -0800

    [SPARK-54924][SS] State Rewriter to read state, transform it and write new state
    
    ### What changes were proposed in this pull request?
    
    Introduce a State Rewriter that rewrites the state stores of a stateful streaming query: it reads the state, transforms it, and then writes the new state.
    
    ### Why are the changes needed?
    
    For offline state repartitioning and other future use cases.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Updated existing tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Yes, Claude-4.5-opus
    
    Closes #53703 from micheal-o/state_rewriter.
    
    Authored-by: micheal-o <[email protected]>
    Signed-off-by: Anish Shrigondekar <[email protected]>
---
 .../src/main/resources/error/error-conditions.json |  26 ++
 .../streaming/state/OperatorStateMetadata.scala    |   9 +
 .../streaming/state/StatePartitionWriter.scala     |   3 +-
 .../execution/streaming/state/StateRewriter.scala  | 404 +++++++++++++++++++++
 .../v2/state/StateDataSourceTestBase.scala         |   2 +-
 ...tatePartitionAllColumnFamiliesWriterSuite.scala | 324 +++++------------
 6 files changed, 542 insertions(+), 226 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json 
b/common/utils/src/main/resources/error/error-conditions.json
index 18ef49f01bef..1f9c7321b50f 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -5587,6 +5587,32 @@
     },
     "sqlState" : "42616"
   },
+  "STATE_REWRITER_INVALID_CHECKPOINT" : {
+    "message" : [
+      "The state rewrite checkpoint location '<checkpointLocation>' is in an invalid state."
+    ],
+    "subClass" : {
+      "MISSING_KEY_ENCODER_SPEC" : {
+        "message" : [
+          "Key state encoder spec is expected for column family '<colFamilyName>' but was not found.",
+          "This is likely a bug, please report it."
+        ]
+      },
+      "MISSING_OPERATOR_METADATA" : {
+        "message" : [
+          "No stateful operator metadata was found for batch <batchId>.",
+          "Ensure that the checkpoint is for a stateful streaming query and the query ran on a Spark version that supports operator metadata (Spark 4.0+)."
+        ]
+      },
+      "UNSUPPORTED_STATE_STORE_METADATA_VERSION" : {
+        "message" : [
+          "Unsupported state store metadata version encountered.",
+          "Only StateStoreMetadataV1 and StateStoreMetadataV2 are supported."
+        ]
+      }
+    },
+    "sqlState" : "55019"
+  },
   "STATE_STORE_CANNOT_CREATE_COLUMN_FAMILY_WITH_RESERVED_CHARS" : {
     "message" : [
       "Failed to create column family with unsupported starting character and 
name=<colFamilyName>."
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala
index c34545216fda..6b2295da03b9 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.streaming.state
 import java.io.{BufferedReader, InputStreamReader}
 import java.nio.charset.StandardCharsets
 
+import scala.collection.immutable.ArraySeq
 import scala.reflect.ClassTag
 
 import org.apache.hadoop.conf.Configuration
@@ -80,12 +81,17 @@ trait OperatorStateMetadata {
   def version: Int
 
   def operatorInfo: OperatorInfo
+
+  def stateStoresMetadata: Seq[StateStoreMetadata]
 }
 
 case class OperatorStateMetadataV1(
     operatorInfo: OperatorInfoV1,
     stateStoreInfo: Array[StateStoreMetadataV1]) extends OperatorStateMetadata 
{
   override def version: Int = 1
+
+  /** Exposes the V1 store entries as a `Seq` view over the array, without copying it. */
+  override def stateStoresMetadata: Seq[StateStoreMetadata] =
+    ArraySeq.unsafeWrapArray(stateStoreInfo)
 }
 
 case class OperatorStateMetadataV2(
@@ -93,6 +99,9 @@ case class OperatorStateMetadataV2(
     stateStoreInfo: Array[StateStoreMetadataV2],
     operatorPropertiesJson: String) extends OperatorStateMetadata {
   override def version: Int = 2
+
+  override def stateStoresMetadata: Seq[StateStoreMetadata] =
+    ArraySeq.unsafeWrapArray(stateStoreInfo)
 }
 
 object OperatorStateMetadataUtils extends Logging {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionWriter.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionWriter.scala
index aac13d3f69f9..3df97d3adc0e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionWriter.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionWriter.scala
@@ -51,7 +51,7 @@ class StatePartitionAllColumnFamiliesWriter(
     hadoopConf: Configuration,
     partitionId: Int,
     targetCpLocation: String,
-    operatorId: Int,
+    operatorId: Long,
     storeName: String,
     currentBatchId: Long,
     colFamilyToWriterInfoMap: Map[String, 
StatePartitionWriterColumnFamilyInfo],
@@ -153,6 +153,7 @@ class StatePartitionAllColumnFamiliesWriter(
       if (!stateStore.hasCommitted) {
         stateStore.abort()
       }
+      provider.close()
     }
   }
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala
new file mode 100644
index 000000000000..da28a3c907f7
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import java.util.UUID
+
+import scala.util.control.NonFatal
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.{SparkException, SparkIllegalStateException, TaskContext}
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.LogKeys._
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions
+import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader
+import org.apache.spark.sql.execution.streaming.checkpointing.OffsetSeqMetadata
+import org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorsUtils
+import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateVariableType, TransformWithStateOperatorProperties, TransformWithStateVariableInfo}
+import org.apache.spark.sql.execution.streaming.runtime.{StreamingCheckpointConstants, StreamingQueryCheckpointMetadata}
+import org.apache.spark.sql.execution.streaming.state.{StatePartitionAllColumnFamiliesWriter, StateSchemaCompatibilityChecker}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.util.{SerializableConfiguration, Utils}
+
+/**
+ * State Rewriter is used to rewrite the state stores for a stateful streaming query.
+ * It reads state from a checkpoint location, optionally applies transformation to the state,
+ * and then writes the state back to a (possibly different) checkpoint location for a new batch ID.
+ *
+ * Example use case is for offline state repartitioning.
+ * Can also be used to support other use cases.
+ *
+ * Operator/state store metadata (and hence the state schema files used for writing) is read
+ * from `resolvedCheckpointLocation` at `readBatchId`; only the state rows themselves are read
+ * from `readResolvedCheckpointLocation` when it is provided.
+ *
+ * @param sparkSession The active Spark session.
+ * @param readBatchId The batch ID for reading state.
+ * @param writeBatchId The batch ID to which the (transformed) state will be written. The offset
+ *                     log of `resolvedCheckpointLocation` must already contain an entry (with
+ *                     metadata) for this batch; its confs are used to write the new state.
+ * @param resolvedCheckpointLocation The resolved checkpoint path where state will be written.
+ * @param hadoopConf Hadoop configuration for file system operations.
+ * @param readResolvedCheckpointLocation Optional separate checkpoint location to read state from.
+ *                                       If None, reads from resolvedCheckpointLocation.
+ * @param transformFunc Optional transformation function applied to each operator's state
+ *                      DataFrame. It must return a DataFrame with the same schema as its input.
+ *                      If None, state is written as-is.
+ * @param writeCheckpointMetadata Optional checkpoint metadata for the resolvedCheckpointLocation.
+ *                                If None, will create a new one for resolvedCheckpointLocation.
+ *                                Helps us to reuse already cached checkpoint log entries,
+ *                                instead of starting from scratch.
+ */
+class StateRewriter(
+    sparkSession: SparkSession,
+    readBatchId: Long,
+    writeBatchId: Long,
+    resolvedCheckpointLocation: String,
+    hadoopConf: Configuration,
+    readResolvedCheckpointLocation: Option[String] = None,
+    transformFunc: Option[DataFrame => DataFrame] = None,
+    writeCheckpointMetadata: Option[StreamingQueryCheckpointMetadata] = None
+) extends Logging {
+  // Reading and writing the same checkpoint requires the write batch to come after the read one.
+  require(readResolvedCheckpointLocation.isDefined || readBatchId < writeBatchId,
+    s"Read batch id $readBatchId must be less than write batch id $writeBatchId " +
+      "when reading and writing to the same checkpoint location")
+
+  // If a different location was specified for reading state, use it.
+  // Else, use same location for reading and writing state.
+  private val checkpointLocationForRead =
+    readResolvedCheckpointLocation.getOrElse(resolvedCheckpointLocation)
+  private val stateRootLocation = new Path(
+    resolvedCheckpointLocation, StreamingCheckpointConstants.DIR_NAME_STATE).toString
+
+  /**
+   * Rewrites every state store of every stateful operator listed in the operator metadata at
+   * `readBatchId`, and logs the total time taken. Failures are logged and rethrown.
+   */
+  def run(): Unit = {
+    logInfo(log"Starting state rewrite for " +
+      log"checkpointLocation=${MDC(CHECKPOINT_LOCATION, resolvedCheckpointLocation)}, " +
+      log"readCheckpointLocation=${MDC(CHECKPOINT_LOCATION, checkpointLocationForRead)}, " +
+      log"readBatchId=${MDC(BATCH_ID, readBatchId)}, " +
+      log"writeBatchId=${MDC(BATCH_ID, writeBatchId)}")
+
+    val (_, timeTakenMs) = Utils.timeTakenMs {
+      runInternal()
+    }
+
+    logInfo(log"State rewrite completed in ${MDC(DURATION, timeTakenMs)} ms for " +
+      log"checkpointLocation=${MDC(CHECKPOINT_LOCATION, resolvedCheckpointLocation)}")
+  }
+
+  private def runInternal(): Unit = {
+    try {
+      // NOTE(review): operator metadata (and therefore the schema files used by the writer) is
+      // read from the write location, not `checkpointLocationForRead`, so the write location
+      // must already have operator metadata for `readBatchId`. Confirm this is intended.
+      val stateMetadataReader = new StateMetadataPartitionReader(
+        resolvedCheckpointLocation,
+        new SerializableConfiguration(hadoopConf),
+        readBatchId)
+
+      val allOperatorsMetadata = stateMetadataReader.allOperatorStateMetadata
+      if (allOperatorsMetadata.isEmpty) {
+        // It's possible that the query is stateless
+        // or ran on an older Spark version without operator metadata
+        throw StateRewriterErrors.missingOperatorMetadataError(
+          resolvedCheckpointLocation, readBatchId)
+      }
+
+      // Use the same conf in the offset log to create the store conf,
+      // to make sure the state is written with the right conf.
+      val (storeConf, sqlConf) = createConfsFromOffsetLog()
+      // SQLConf doesn't serialize properly (reader becomes null), so extract as Map
+      val sqlConfEntries: Map[String, String] = sqlConf.getAllConfs
+
+      // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
+      val hadoopConfBroadcast =
+        SerializableConfiguration.broadcast(sparkSession.sparkContext, hadoopConf)
+
+      // Do rewrite for each operator
+      // We can potentially parallelize this, but for now, do sequentially
+      allOperatorsMetadata.foreach { opMetadata =>
+        val stateStoresMetadata = opMetadata.stateStoresMetadata
+        assert(stateStoresMetadata.nonEmpty,
+          s"Operator ${opMetadata.operatorInfo.operatorName} has no state stores")
+
+        val storeToSchemaFilesMap = getStoreToSchemaFilesMap(opMetadata)
+        val stateVarsIfTws = getStateVariablesIfTWS(opMetadata)
+
+        // Rewrite each state store of the operator
+        stateStoresMetadata.foreach { stateStoreMetadata =>
+          rewriteStore(
+            opMetadata,
+            stateStoreMetadata,
+            storeConf,
+            hadoopConfBroadcast,
+            storeToSchemaFilesMap(stateStoreMetadata.storeName),
+            stateVarsIfTws,
+            sqlConfEntries
+          )
+        }
+      }
+    } catch {
+      // Add rewrite context to the log before propagating; fatal JVM errors are not intercepted.
+      case NonFatal(e) =>
+        logError(log"State rewrite failed for " +
+          log"checkpointLocation=${MDC(CHECKPOINT_LOCATION, resolvedCheckpointLocation)}, " +
+          log"readBatchId=${MDC(BATCH_ID, readBatchId)}, " +
+          log"writeBatchId=${MDC(BATCH_ID, writeBatchId)}", e)
+        throw e
+    }
+  }
+
+  /**
+   * Reads all column families of one state store of an operator, applies `transformFunc` if
+   * provided, and writes every partition of the result to `resolvedCheckpointLocation` for
+   * `writeBatchId`.
+   */
+  private def rewriteStore(
+      opMetadata: OperatorStateMetadata,
+      stateStoreMetadata: StateStoreMetadata,
+      storeConf: StateStoreConf,
+      hadoopConfBroadcast: Broadcast[SerializableConfiguration],
+      storeSchemaFiles: List[Path],
+      stateVarsIfTws: Map[String, TransformWithStateVariableInfo],
+      sqlConfEntries: Map[String, String]
+  ): Unit = {
+    // Read state
+    val stateDf = sparkSession.read
+      .format("statestore")
+      .option(StateSourceOptions.PATH, checkpointLocationForRead)
+      .option(StateSourceOptions.BATCH_ID, readBatchId)
+      .option(StateSourceOptions.OPERATOR_ID, opMetadata.operatorInfo.operatorId)
+      .option(StateSourceOptions.STORE_NAME, stateStoreMetadata.storeName)
+      .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, "true")
+      .load()
+
+    // Run the caller state transformation func if provided
+    // Otherwise, use the state as is
+    val updatedStateDf = transformFunc.map(func => func(stateDf)).getOrElse(stateDf)
+    require(updatedStateDf.schema == stateDf.schema,
+      s"State transformation function must return a DataFrame with the same schema " +
+        s"as the original state DataFrame. Original schema: ${stateDf.schema}, " +
+        s"Updated schema: ${updatedStateDf.schema}")
+
+    val schemaProvider = createStoreSchemaProviderIfTWS(
+      opMetadata.operatorInfo.operatorName,
+      storeSchemaFiles
+    )
+    val writerColFamilyInfoMap = getWriterColFamilyInfoMap(
+      opMetadata.operatorInfo.operatorId,
+      stateStoreMetadata,
+      storeSchemaFiles,
+      stateVarsIfTws
+    )
+
+    logInfo(log"Writing new state for " +
+      log"operator=${MDC(OP_TYPE, opMetadata.operatorInfo.operatorName)}, " +
+      log"stateStore=${MDC(STATE_NAME, stateStoreMetadata.storeName)}, " +
+      log"numColumnFamilies=${MDC(COUNT, writerColFamilyInfoMap.size)}, " +
+      log"numSchemaFiles=${MDC(NUM_FILES, storeSchemaFiles.size)}, " +
+      log"for new batch=${MDC(BATCH_ID, writeBatchId)}, " +
+      log"for checkpoint=${MDC(CHECKPOINT_LOCATION, resolvedCheckpointLocation)}")
+
+    // Write state for each partition on the executor.
+    // Copy the fields into local vals so the closure below does not capture (and try to
+    // serialize) the whole Rewriter object.
+    val targetCheckpointLocation = resolvedCheckpointLocation
+    val currentBatchId = writeBatchId
+    updatedStateDf.queryExecution.toRdd.foreachPartition { partitionIter =>
+      // Recreate SQLConf on executor from serialized entries
+      val executorSqlConf = new SQLConf()
+      sqlConfEntries.foreach { case (k, v) => executorSqlConf.setConfString(k, v) }
+
+      val partitionWriter = new StatePartitionAllColumnFamiliesWriter(
+        storeConf,
+        hadoopConfBroadcast.value.value,
+        TaskContext.get().partitionId(),
+        targetCheckpointLocation,
+        opMetadata.operatorInfo.operatorId,
+        stateStoreMetadata.storeName,
+        currentBatchId,
+        writerColFamilyInfoMap,
+        opMetadata.operatorInfo.operatorName,
+        schemaProvider,
+        executorSqlConf
+      )
+
+      partitionWriter.write(partitionIter)
+    }
+  }
+
+  /**
+   * Create the store and sql confs from the conf written in the offset log for `writeBatchId`.
+   * Fails if that batch, or its metadata, is missing from the offset log.
+   */
+  private def createConfsFromOffsetLog(): (StateStoreConf, SQLConf) = {
+    val offsetLog = writeCheckpointMetadata.getOrElse(
+      new StreamingQueryCheckpointMetadata(sparkSession, resolvedCheckpointLocation)).offsetLog
+
+    // We want to use the same confs written in the offset log for the new batch
+    val offsetSeq = offsetLog.get(writeBatchId)
+    require(offsetSeq.isDefined, s"Offset seq must be present for the new batch $writeBatchId")
+    val metadata = offsetSeq.get.metadataOpt
+    require(metadata.isDefined, s"Metadata must be present for the new batch $writeBatchId")
+
+    val clonedSqlConf = sparkSession.sessionState.conf.clone()
+    OffsetSeqMetadata.setSessionConf(metadata.get, clonedSqlConf)
+    (StateStoreConf(clonedSqlConf), clonedSqlConf)
+  }
+
+  /** Get the map of state store name to schema files, for an operator */
+  private def getStoreToSchemaFilesMap(
+      opMetadata: OperatorStateMetadata): Map[String, List[Path]] = {
+    opMetadata.stateStoresMetadata.map { storeMetadata =>
+      val schemaFiles = storeMetadata match {
+        // No schema files for v1. It has a fixed/known schema file path
+        case _: StateStoreMetadataV1 => List.empty[Path]
+        case v2: StateStoreMetadataV2 => v2.stateSchemaFilePaths.map(new Path(_))
+        case _ =>
+          throw StateRewriterErrors.unsupportedStateStoreMetadataVersionError(
+            resolvedCheckpointLocation)
+      }
+      storeMetadata.storeName -> schemaFiles
+    }.toMap
+  }
+
+  /**
+   * Builds the writer info for every column family of a store. A column family that is a
+   * TWS ListState variable is written with multiple values per key.
+   */
+  private def getWriterColFamilyInfoMap(
+      operatorId: Long,
+      storeMetadata: StateStoreMetadata,
+      schemaFiles: List[Path],
+      twsStateVariables: Map[String, TransformWithStateVariableInfo] = Map.empty
+  ): Map[String, StatePartitionWriterColumnFamilyInfo] = {
+    getLatestColFamilyToSchemaMap(operatorId, storeMetadata, schemaFiles)
+      .map { case (colFamilyName, schema) =>
+        colFamilyName -> StatePartitionWriterColumnFamilyInfo(schema,
+          useMultipleValuesPerKey = twsStateVariables.get(colFamilyName)
+            .exists(_.stateVariableType == StateVariableType.ListState))
+      }
+  }
+
+  /**
+   * Reads the latest schema of every column family in a store and fills in the key encoder
+   * spec for v1 metadata when it is missing.
+   */
+  private def getLatestColFamilyToSchemaMap(
+      operatorId: Long,
+      storeMetadata: StateStoreMetadata,
+      schemaFiles: List[Path]): Map[String, StateStoreColFamilySchema] = {
+    val storeId = new StateStoreId(
+      stateRootLocation,
+      operatorId,
+      StateStore.PARTITION_ID_TO_CHECK_SCHEMA,
+      storeMetadata.storeName)
+    // using a placeholder runId since we are not running a streaming query
+    val providerId = new StateStoreProviderId(storeId, queryRunId = UUID.randomUUID())
+    val manager = new StateSchemaCompatibilityChecker(providerId, hadoopConf,
+      oldSchemaFilePaths = schemaFiles)
+    // Read the latest state schema from the provided path for v2 or from the dedicated path
+    // for v1
+    manager
+      .readSchemaFile()
+      .map { schema =>
+        schema.colFamilyName -> createKeyEncoderSpecIfAbsent(schema, storeMetadata)
+      }
+      .toMap
+  }
+
+  /**
+   * Returns the schema unchanged when it already has a key encoder spec. For v1 metadata,
+   * derives one from the number of prefix key columns. Otherwise the checkpoint is invalid.
+   */
+  private def createKeyEncoderSpecIfAbsent(
+      colFamilySchema: StateStoreColFamilySchema,
+      storeMetadata: StateStoreMetadata): StateStoreColFamilySchema = {
+    (colFamilySchema.keyStateEncoderSpec, storeMetadata) match {
+      case (Some(_), _) =>
+        colFamilySchema
+      case (None, v1: StateStoreMetadataV1) =>
+        // v1 metadata may not carry a key encoder spec; create one from the prefix key columns
+        val encoderSpec = if (v1.numColsPrefixKey > 0) {
+          PrefixKeyScanStateEncoderSpec(colFamilySchema.keySchema, v1.numColsPrefixKey)
+        } else {
+          NoPrefixKeyStateEncoderSpec(colFamilySchema.keySchema)
+        }
+        colFamilySchema.copy(keyStateEncoderSpec = Some(encoderSpec))
+      case _ =>
+        // Key encoder spec is expected in v2 metadata
+        throw StateRewriterErrors.missingKeyEncoderSpecError(
+          resolvedCheckpointLocation, colFamilySchema.colFamilyName)
+    }
+  }
+
+  /**
+   * For a transformWithState operator, returns its state variables keyed by state name
+   * (parsed from the v2 operator properties). Returns an empty map for other operators.
+   */
+  private def getStateVariablesIfTWS(
+      opMetadata: OperatorStateMetadata): Map[String, TransformWithStateVariableInfo] = {
+    val opName = opMetadata.operatorInfo.operatorName
+    if (StatefulOperatorsUtils.TRANSFORM_WITH_STATE_OP_NAMES.contains(opName)) {
+      opMetadata match {
+        case v2: OperatorStateMetadataV2 =>
+          val operatorProperties =
+            TransformWithStateOperatorProperties.fromJson(v2.operatorPropertiesJson)
+          operatorProperties.stateVariables.map(s => s.stateName -> s).toMap
+        case other =>
+          // Only v2 operator metadata carries the operator properties JSON; fail with a clear
+          // message rather than a ClassCastException.
+          throw SparkException.internalError(
+            s"Expected OperatorStateMetadataV2 for operator $opName but found " +
+              s"version ${other.version}")
+      }
+    } else {
+      Map.empty
+    }
+  }
+
+  // Needed only for schema evolution for TWS
+  private def createStoreSchemaProviderIfTWS(
+      opName: String,
+      schemaFiles: List[Path]): Option[StateSchemaProvider] = {
+    if (StatefulOperatorsUtils.TRANSFORM_WITH_STATE_OP_NAMES.contains(opName)) {
+      val schemaMetadata = StateSchemaMetadata.createStateSchemaMetadata(
+        stateRootLocation, hadoopConf, schemaFiles.map(_.toString))
+      Some(new InMemoryStateSchemaProvider(schemaMetadata))
+    } else {
+      None
+    }
+  }
+}
+
+/**
+ * Factory methods for the errors raised by [[StateRewriter]] when the checkpoint it works on
+ * is in an invalid state. Each method builds and returns the error; callers throw it.
+ */
+private[state] object StateRewriterErrors {
+  /** A column family schema lacks the key state encoder spec it is expected to have. */
+  def missingKeyEncoderSpecError(
+      checkpointLocation: String,
+      colFamilyName: String): StateRewriterInvalidCheckpointError =
+    new StateRewriterMissingKeyEncoderSpecError(checkpointLocation, colFamilyName)
+
+  /** No stateful operator metadata exists for `batchId`. */
+  def missingOperatorMetadataError(
+      checkpointLocation: String,
+      batchId: Long): StateRewriterInvalidCheckpointError =
+    new StateRewriterMissingOperatorMetadataError(checkpointLocation, batchId)
+
+  /** A state store metadata entry is of a version the rewriter does not handle. */
+  def unsupportedStateStoreMetadataVersionError(
+      checkpointLocation: String): StateRewriterInvalidCheckpointError =
+    new StateRewriterUnsupportedStoreMetadataVersionError(checkpointLocation)
+}
+
+/**
+ * Base class for exceptions thrown when the checkpoint location is in an invalid state
+ * for state rewriting. Every subclass uses the `STATE_REWRITER_INVALID_CHECKPOINT` error class
+ * and always reports the offending `checkpointLocation`.
+ *
+ * @param checkpointLocation checkpoint location reported in the error message.
+ * @param subClass sub-condition appended to `STATE_REWRITER_INVALID_CHECKPOINT`.
+ * @param messageParameters extra message parameters required by the sub-condition.
+ * @param cause optional underlying cause (null when there is none).
+ */
+private[state] abstract class StateRewriterInvalidCheckpointError(
+    checkpointLocation: String,
+    subClass: String,
+    messageParameters: Map[String, String],
+    cause: Throwable = null)
+  extends SparkIllegalStateException(
+    errorClass = "STATE_REWRITER_INVALID_CHECKPOINT." + subClass,
+    messageParameters = Map("checkpointLocation" -> checkpointLocation).concat(messageParameters),
+    cause = cause)
+
+/** Raised when a column family's schema is missing its expected key state encoder spec. */
+private[state] class StateRewriterMissingKeyEncoderSpecError(
+    checkpointLocation: String,
+    colFamilyName: String)
+  extends StateRewriterInvalidCheckpointError(
+    checkpointLocation = checkpointLocation,
+    subClass = "MISSING_KEY_ENCODER_SPEC",
+    messageParameters = Map("colFamilyName" -> colFamilyName))
+
+/** Raised when no stateful operator metadata is found for the requested batch. */
+private[state] class StateRewriterMissingOperatorMetadataError(
+    checkpointLocation: String,
+    batchId: Long)
+  extends StateRewriterInvalidCheckpointError(
+    checkpointLocation = checkpointLocation,
+    subClass = "MISSING_OPERATOR_METADATA",
+    messageParameters = Map("batchId" -> String.valueOf(batchId)))
+
+/** Raised when a state store metadata entry is of an unsupported version. */
+private[state] class StateRewriterUnsupportedStoreMetadataVersionError(
+    checkpointLocation: String)
+  extends StateRewriterInvalidCheckpointError(
+    checkpointLocation = checkpointLocation,
+    subClass = "UNSUPPORTED_STATE_STORE_METADATA_VERSION",
+    messageParameters = Map.empty[String, String])
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala
index 6a1f66262d5f..d8a3bbb65af2 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala
@@ -666,7 +666,7 @@ object SessionWindowTestUtils {
  */
 object StreamStreamJoinTestUtils {
   // All state store names from SymmetricHashJoinStateManager
-  private val allStoreNames: Seq[String] =
+  val allStoreNames: Seq[String] =
     SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide)
 
   // Column family names for keyToNumValues stores (derived from 
allStateStoreNames)
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala
index 9501e4e9e36b..e495db499bfe 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala
@@ -20,10 +20,8 @@ import java.io.File
 import java.sql.Timestamp
 import java.time.Duration
 
-import org.apache.spark.TaskContext
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.InternalRow
-import 
org.apache.spark.sql.execution.datasources.v2.state.{CompositeKeyAggregationTestUtils,
 DropDuplicatesTestUtils, FlatMapGroupsWithStateTestUtils, 
SessionWindowTestUtils, SimpleAggregationTestUtils, StateDataSourceTestBase, 
StateSourceOptions, StreamStreamJoinTestUtils}
+import 
org.apache.spark.sql.execution.datasources.v2.state.{StateDataSourceTestBase, 
StateSourceOptions, StreamStreamJoinTestUtils}
 import 
org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorsUtils
 import 
org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.timers.TimerStateUtils
 import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, 
StreamingQueryCheckpointMetadata}
@@ -34,7 +32,6 @@ import org.apache.spark.sql.streaming.{InputEvent, 
ListStateTTLProcessor, MapInp
 import org.apache.spark.sql.streaming.util.{StreamManualClock, 
TTLProcessorUtils}
 import org.apache.spark.sql.streaming.util.{EventTimeTimerProcessor, 
MultiStateVarProcessor, MultiStateVarProcessorTestUtils, TimerTestUtils}
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.util.SerializableConfiguration
 
 /**
  * Test suite for StatePartitionAllColumnFamiliesWriter.
@@ -51,68 +48,33 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
     spark.conf.set(SQLConf.SHUFFLE_PARTITIONS.key, "2")
   }
 
-  /**
-   * Helper method to create a StateSchemaProvider from column family schema 
map.
-   */
-  private def createStateSchemaProvider(
-      columnFamilyToSchemaMap: Map[String, 
StatePartitionWriterColumnFamilyInfo]
-  ): StateSchemaProvider = {
-    val testSchemaProvider = new TestStateSchemaProvider()
-    columnFamilyToSchemaMap.foreach { case (cfName, cfInfo) =>
-      testSchemaProvider.captureSchema(
-        colFamilyName = cfName,
-        keySchema = cfInfo.schema.keySchema,
-        valueSchema = cfInfo.schema.valueSchema,
-        keySchemaId = cfInfo.schema.keySchemaId,
-        valueSchemaId = cfInfo.schema.valueSchemaId
-      )
-    }
-    testSchemaProvider
-  }
-
   /**
    * Common helper method to perform round-trip test: read state bytes from 
source,
    * write to target, and verify target matches source.
    *
    * @param sourceDir Source checkpoint directory
    * @param targetDir Target checkpoint directory
-   * @param columnFamilyToSchemaMap Map of column family names to their schemas
-   * @param storeName Optional store name (for stream-stream join which has 
multiple stores)
-   * @param columnFamilyToSelectExprs Map of column family names to custom 
selectExprs
-   * @param columnFamilyToStateSourceOptions Map of column family names to 
state source options
+   * @param storeToColumnFamilies Optional store name to its column families
+   * @param storeToColumnFamilyToSelectExprs Map store name to per column 
family custom selectExprs
+   * @param storeToColumnFamilyToStateSourceOptions Map store name to per 
column family
+   *                                                state source options
    */
   private def performRoundTripTest(
       sourceDir: String,
       targetDir: String,
-      columnFamilyToSchemaMap: Map[String, 
StatePartitionWriterColumnFamilyInfo],
-      storeName: Option[String] = None,
-      columnFamilyToSelectExprs: Map[String, Seq[String]] = Map.empty,
-      columnFamilyToStateSourceOptions: Map[String, Map[String, String]] = 
Map.empty,
+      storeToColumnFamilies: Map[String, List[String]] =
+        Map(StateStoreId.DEFAULT_STORE_NAME -> 
List(StateStore.DEFAULT_COL_FAMILY_NAME)),
+      storeToColumnFamilyToSelectExprs: Map[String, Map[String, Seq[String]]] 
= Map.empty,
+      storeToColumnFamilyToStateSourceOptions: Map[String, Map[String, 
Map[String, String]]] =
+        Map.empty,
       operatorName: String): Unit = {
-
-    val columnFamiliesToValidate: Seq[String] = if 
(columnFamilyToSchemaMap.size > 1) {
-      columnFamilyToSchemaMap.keys.toSeq
-    } else {
-      Seq(StateStore.DEFAULT_COL_FAMILY_NAME)
-    }
-
-    // Step 1: Read from source using AllColumnFamiliesReader (raw bytes)
-    val sourceBytesReader = spark.read
-      .format("statestore")
-      .option(StateSourceOptions.PATH, sourceDir)
-      .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, 
"true")
-    val sourceBytesData = (storeName match {
-      case Some(name) => 
sourceBytesReader.option(StateSourceOptions.STORE_NAME, name)
-      case None => sourceBytesReader
-    }).load()
-
-    // Verify schema of raw bytes
-    val schema = sourceBytesData.schema
-    assert(schema.fieldNames === Array(
-      "partition_key", "key_bytes", "value_bytes", "column_family_name"))
-
-    // Step 2: Write raw bytes to target checkpoint location
     val hadoopConf = spark.sessionState.newHadoopConf()
+    val sourceCpLocation = StreamingUtils.resolvedCheckpointLocation(
+      hadoopConf, sourceDir)
+    val sourceCheckpointMetadata = new StreamingQueryCheckpointMetadata(
+      spark, sourceCpLocation)
+    val readBatchId = sourceCheckpointMetadata.commitLog.getLatestBatchId().get
+
     val targetCpLocation = StreamingUtils.resolvedCheckpointLocation(
       hadoopConf, targetDir)
     val targetCheckpointMetadata = new StreamingQueryCheckpointMetadata(
@@ -120,67 +82,56 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
     // increase offsetCheckpoint
     val lastBatch = targetCheckpointMetadata.commitLog.getLatestBatchId().get
     val targetOffsetSeq = targetCheckpointMetadata.offsetLog.get(lastBatch).get
-    val currentBatchId = lastBatch + 1
-    targetCheckpointMetadata.offsetLog.add(currentBatchId, targetOffsetSeq)
-
-    val storeConf: StateStoreConf = StateStoreConf(spark.sessionState.conf)
-    val serializableHadoopConf = new SerializableConfiguration(hadoopConf)
-
-    // Create StateSchemaProvider if needed (for Avro encoding)
-    val stateSchemaProvider = if (storeConf.stateStoreEncodingFormat == 
"avro") {
-      Some(createStateSchemaProvider(columnFamilyToSchemaMap))
-    } else {
-      None
-    }
-    val baseConfs: Map[String, String] = spark.sessionState.conf.getAllConfs
-    val putPartitionFunc: Iterator[InternalRow] => Unit = partition => {
-      val newConf = new SQLConf
-      baseConfs.foreach { case (k, v) =>
-        newConf.setConfString(k, v)
-      }
-      val allCFWriter = new StatePartitionAllColumnFamiliesWriter(
-        storeConf,
-        serializableHadoopConf.value,
-        TaskContext.getPartitionId(),
-        targetCpLocation,
-        0,
-        storeName.getOrElse(StateStoreId.DEFAULT_STORE_NAME),
-        currentBatchId,
-        columnFamilyToSchemaMap,
-        operatorName,
-        stateSchemaProvider,
-        newConf
-      )
-      allCFWriter.write(partition)
-    }
-    sourceBytesData.queryExecution.toRdd.foreachPartition(putPartitionFunc)
+    val writeBatchId = lastBatch + 1
+    targetCheckpointMetadata.offsetLog.add(writeBatchId, targetOffsetSeq)
+
+    val rewriter = new StateRewriter(
+      spark,
+      readBatchId,
+      writeBatchId,
+      targetCpLocation,
+      hadoopConf,
+      readResolvedCheckpointLocation = Some(sourceCpLocation),
+      transformFunc = None,
+      writeCheckpointMetadata = Some(targetCheckpointMetadata)
+    )
+    rewriter.run()
 
     // Commit to commitLog
     val latestCommit = targetCheckpointMetadata.commitLog.get(lastBatch).get
-    targetCheckpointMetadata.commitLog.add(currentBatchId, latestCommit)
-    val versionToCheck = currentBatchId + 1
-    val storeNamePath = s"state/0/0${storeName.fold("")("/" + _)}"
-    assert(!checkpointFileExists(new File(targetDir, storeNamePath), 
versionToCheck, ".changelog"))
-    assert(checkpointFileExists(new File(targetDir, storeNamePath), 
versionToCheck, ".zip"))
+    targetCheckpointMetadata.commitLog.add(writeBatchId, latestCommit)
+    val versionToCheck = writeBatchId + 1
+
+    storeToColumnFamilies.foreach { case (storeName, columnFamilies) =>
+      val storeNamePath = if (storeName == StateStoreId.DEFAULT_STORE_NAME) {
+        "state/0/0"
+      } else {
+        s"state/0/0/$storeName"
+      }
+      assert(!checkpointFileExists(new File(targetDir, storeNamePath),
+        versionToCheck, ".changelog"))
+      assert(checkpointFileExists(new File(targetDir, storeNamePath), 
versionToCheck, ".zip"))
 
-    // Step 3: Validate by reading from both source and target using normal 
reader"
-    // Default selectExprs for most column families
-    val defaultSelectExprs = Seq("key", "value", "partition_id")
+      // Validate by reading from both source and target using the normal reader
+      // Default selectExprs for most column families
+      val defaultSelectExprs = Seq("key", "value", "partition_id")
 
-    columnFamiliesToValidate
+      columnFamilies
       // filtering out "default" for TWS operator because it doesn't contain 
any data
       .filter(cfName => !(cfName == StateStore.DEFAULT_COL_FAMILY_NAME &&
         
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_OP_NAMES.contains(operatorName)
       ))
       .foreach { cfName =>
-        val selectExprs = columnFamilyToSelectExprs.getOrElse(cfName, 
defaultSelectExprs)
-        val readerOptions = columnFamilyToStateSourceOptions.getOrElse(cfName, 
Map.empty)
+        val selectExprs = 
storeToColumnFamilyToSelectExprs.getOrElse(storeName, Map.empty)
+          .getOrElse(cfName, defaultSelectExprs)
+        val readerOptions = 
storeToColumnFamilyToStateSourceOptions.getOrElse(storeName, Map.empty)
+          .getOrElse(cfName, Map.empty)
 
         def readNormalData(dir: String): Array[Row] = {
           var reader = spark.read
             .format("statestore")
             .option(StateSourceOptions.PATH, dir)
-            .option(StateSourceOptions.STORE_NAME, storeName.orNull)
+            .option(StateSourceOptions.STORE_NAME, storeName)
           readerOptions.foreach { case (k, v) => reader = reader.option(k, v) }
           reader.load()
             .selectExpr(selectExprs: _*)
@@ -192,6 +143,7 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
 
         validateDataMatches(sourceNormalData, targetNormalData)
       }
+    }
   }
 
   /**
@@ -308,15 +260,10 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
             )
           )
 
-          // Step 2: Define schemas based on state version
-          val metadata = 
SimpleAggregationTestUtils.getSchemasWithMetadata(stateVersion)
-
           // Perform round-trip test using common helper
           performRoundTripTest(
             sourceDir.getAbsolutePath,
             targetDir.getAbsolutePath,
-            createSingleColumnFamilySchemaMap(
-              metadata.keySchema, metadata.valueSchema, metadata.encoderSpec),
             operatorName = StatefulOperatorsUtils.STATE_STORE_SAVE_EXEC_OP_NAME
           )
         }
@@ -349,15 +296,10 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
             )
           )
 
-          // Step 2: Define schemas based on state version for composite key
-          val metadata = 
CompositeKeyAggregationTestUtils.getSchemasWithMetadata(stateVersion)
-
           // Perform round-trip test using common helper
           performRoundTripTest(
             sourceDir.getAbsolutePath,
             targetDir.getAbsolutePath,
-            createSingleColumnFamilySchemaMap(
-              metadata.keySchema, metadata.valueSchema, metadata.encoderSpec),
             operatorName = StatefulOperatorsUtils.STATE_STORE_SAVE_EXEC_OP_NAME
           )
         }
@@ -396,36 +338,15 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
           )
 
           // Step 2: Test all 4 state stores created by stream-stream join
-          // Test keyToNumValues stores (both left and right)
-          StreamStreamJoinTestUtils.KEY_TO_NUM_VALUES_ALL.foreach { storeName 
=>
-            val metadata = 
StreamStreamJoinTestUtils.getKeyToNumValuesSchemasWithMetadata()
-
-            // Perform round-trip test using common helper
-            performRoundTripTest(
-              sourceDir.getAbsolutePath,
-              targetDir.getAbsolutePath,
-              createSingleColumnFamilySchemaMap(
-                metadata.keySchema, metadata.valueSchema, 
metadata.encoderSpec),
-              storeName = Some(storeName),
-              operatorName = 
StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME
-            )
-          }
-
-          // Test keyWithIndexToValue stores (both left and right)
-          StreamStreamJoinTestUtils.KEY_WITH_INDEX_ALL.foreach { storeName =>
-            val metadata =
-              
StreamStreamJoinTestUtils.getKeyWithIndexToValueSchemasWithMetadata(stateVersion)
-
-            // Perform round-trip test using common helper
-            performRoundTripTest(
-              sourceDir.getAbsolutePath,
-              targetDir.getAbsolutePath,
-              createSingleColumnFamilySchemaMap(
-                metadata.keySchema, metadata.valueSchema, 
metadata.encoderSpec),
-              storeName = Some(storeName),
-              operatorName = 
StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME
-            )
-          }
+          val storeToColumnFamilies = StreamStreamJoinTestUtils.allStoreNames
+            .map(s => s -> List(StateStore.DEFAULT_COL_FAMILY_NAME)).toMap
+          // Perform round-trip test using common helper
+          performRoundTripTest(
+            sourceDir.getAbsolutePath,
+            targetDir.getAbsolutePath,
+            storeToColumnFamilies,
+            operatorName = 
StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME
+          )
         }
       }
     }
@@ -458,15 +379,10 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
             CheckLastBatch(("a", 1, 0, false))
           )
 
-          // Step 2: Define schemas for flatMapGroupsWithState
-          val metadata = 
FlatMapGroupsWithStateTestUtils.getSchemasWithMetadata(stateVersion)
-
           // Perform round-trip test using common helper
           performRoundTripTest(
             sourceDir.getAbsolutePath,
             targetDir.getAbsolutePath,
-            createSingleColumnFamilySchemaMap(
-              metadata.keySchema, metadata.valueSchema, metadata.encoderSpec),
             operatorName = 
StatefulOperatorsUtils.FLAT_MAP_GROUPS_WITH_STATE_EXEC_OP_NAME
           )
         }
@@ -474,7 +390,6 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
     }
   }
 
-
   /**
    * Helper method to build timer column family schemas and options for
    * RunningCountStatefulProcessorWithProcTimeTimer and EventTimeTimerProcessor
@@ -564,16 +479,10 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
             CheckAnswer(("a", 1))
           )
 
-          // Step 2: Define schemas for dropDuplicatesWithinWatermark
-          val metadata =
-            
DropDuplicatesTestUtils.getDropDuplicatesWithinWatermarkSchemasWithMetadata()
-
           // Perform round-trip test using common helper
           performRoundTripTest(
             sourceDir.getAbsolutePath,
             targetDir.getAbsolutePath,
-            createSingleColumnFamilySchemaMap(
-              metadata.keySchema, metadata.valueSchema, metadata.encoderSpec),
             operatorName = 
StatefulOperatorsUtils.DEDUPLICATE_WITHIN_WATERMARK_EXEC_OP_NAME
           )
         }
@@ -595,16 +504,10 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
             CheckAnswer(("a", 1))
           )
 
-          // Step 2: Define schemas for dropDuplicates with column specified
-          val metadata =
-            
DropDuplicatesTestUtils.getDropDuplicatesWithColumnSchemasWithMetadata()
-
           // Perform round-trip test using common helper
           performRoundTripTest(
             sourceDir.getAbsolutePath,
             targetDir.getAbsolutePath,
-            createSingleColumnFamilySchemaMap(
-              metadata.keySchema, metadata.valueSchema, metadata.encoderSpec),
             operatorName = StatefulOperatorsUtils.DEDUPLICATE_EXEC_OP_NAME
           )
         }
@@ -629,16 +532,10 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
             StopStream
           )
 
-          // Step 2: Define schemas for session window aggregation
-          val (keySchema, valueSchema) = SessionWindowTestUtils.getSchemas()
-          // Session window aggregation uses prefix key scanning where 
sessionId is the prefix
-          val keyStateEncoderSpec = PrefixKeyScanStateEncoderSpec(keySchema, 1)
-
           // Perform round-trip test using common helper
           performRoundTripTest(
             sourceDir.getAbsolutePath,
             targetDir.getAbsolutePath,
-            createSingleColumnFamilySchemaMap(keySchema, valueSchema, 
keyStateEncoderSpec),
             operatorName = 
StatefulOperatorsUtils.SESSION_WINDOW_STATE_STORE_SAVE_EXEC_OP_NAME
           )
         }
@@ -660,15 +557,10 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
             assertNumStateRows(total = 6, updated = 6)
           )
 
-          // Step 2: Define schemas for dropDuplicates (state version 2)
-          val metadata = 
DropDuplicatesTestUtils.getDropDuplicatesSchemasWithMetadata()
-
           // Perform round-trip test using common helper
           performRoundTripTest(
             sourceDir.getAbsolutePath,
             targetDir.getAbsolutePath,
-            createSingleColumnFamilySchemaMap(
-              metadata.keySchema, metadata.valueSchema, metadata.encoderSpec),
             operatorName = StatefulOperatorsUtils.DEDUPLICATE_EXEC_OP_NAME
           )
         }
@@ -713,15 +605,15 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
             runQuery(sourceDir.getAbsolutePath, roundsOfData = 2)
             runQuery(targetDir.getAbsolutePath, roundsOfData = 1)
 
-            val allColFamilyNames = 
StreamStreamJoinTestUtils.KEY_TO_NUM_VALUES_ALL ++
-              StreamStreamJoinTestUtils.KEY_WITH_INDEX_ALL
+            val allColFamilyNames = 
StreamStreamJoinTestUtils.allStoreNames.toList
             performRoundTripTest(
               sourceDir.getAbsolutePath,
               targetDir.getAbsolutePath,
-              getJoinV3ColumnSchemaMap(),
-              columnFamilyToStateSourceOptions = allColFamilyNames.map {
-                colName => colName -> Map(StateSourceOptions.STORE_NAME -> 
colName)
-              }.toMap,
+              storeToColumnFamilies = Map(StateStoreId.DEFAULT_STORE_NAME -> 
allColFamilyNames),
+              storeToColumnFamilyToStateSourceOptions =
+                Map(StateStoreId.DEFAULT_STORE_NAME -> allColFamilyNames.map {
+                  cfName => cfName -> Map(StateSourceOptions.STORE_NAME -> 
cfName)
+                }.toMap),
               operatorName = 
StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME
             )
           }
@@ -770,14 +662,6 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
             runQuery(targetDir.getAbsolutePath, 1)
 
             val schemas = 
MultiStateVarProcessorTestUtils.getSchemasWithMetadata()
-            val columnFamilyToSchemaMap = schemas.map { case (cfName, 
metadata) =>
-              cfName -> createColFamilyInfo(
-                metadata.keySchema,
-                metadata.valueSchema,
-                metadata.encoderSpec,
-                cfName,
-                metadata.useMultipleValuePerKey)
-            }
             val columnFamilyToSelectExprs = MultiStateVarProcessorTestUtils
               .getColumnFamilyToSelectExprs()
 
@@ -799,9 +683,11 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
             performRoundTripTest(
               sourceDir.getAbsolutePath,
               targetDir.getAbsolutePath,
-              columnFamilyToSchemaMap,
-              columnFamilyToSelectExprs = columnFamilyToSelectExprs,
-              columnFamilyToStateSourceOptions = 
columnFamilyToStateSourceOptions,
+              storeToColumnFamilies = Map(StateStoreId.DEFAULT_STORE_NAME -> 
schemas.keys.toList),
+              storeToColumnFamilyToSelectExprs =
+                Map(StateStoreId.DEFAULT_STORE_NAME -> 
columnFamilyToSelectExprs),
+              storeToColumnFamilyToStateSourceOptions =
+                Map(StateStoreId.DEFAULT_STORE_NAME -> 
columnFamilyToStateSourceOptions),
               operatorName = 
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
             )
           }
@@ -842,9 +728,12 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
             performRoundTripTest(
               sourceDir.getAbsolutePath,
               targetDir.getAbsolutePath,
-              schemaMap,
-              columnFamilyToSelectExprs = selectExprs,
-              columnFamilyToStateSourceOptions = stateSourceOptions,
+              storeToColumnFamilies =
+                Map(StateStoreId.DEFAULT_STORE_NAME -> schemaMap.keys.toList),
+              storeToColumnFamilyToSelectExprs =
+                Map(StateStoreId.DEFAULT_STORE_NAME -> selectExprs),
+              storeToColumnFamilyToStateSourceOptions =
+                Map(StateStoreId.DEFAULT_STORE_NAME -> stateSourceOptions),
               operatorName = 
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
             )
           }
@@ -890,9 +779,12 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
             performRoundTripTest(
               sourceDir.getAbsolutePath,
               targetDir.getAbsolutePath,
-              schemaMap,
-              columnFamilyToSelectExprs = selectExprs,
-              columnFamilyToStateSourceOptions = sourceOptions,
+              storeToColumnFamilies =
+                Map(StateStoreId.DEFAULT_STORE_NAME -> schemaMap.keys.toList),
+              storeToColumnFamilyToSelectExprs =
+                Map(StateStoreId.DEFAULT_STORE_NAME -> selectExprs),
+              storeToColumnFamilyToStateSourceOptions =
+                Map(StateStoreId.DEFAULT_STORE_NAME -> sourceOptions),
               operatorName = 
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
             )
           }
@@ -933,14 +825,6 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
             )
 
             val schemas = 
TTLProcessorUtils.getListStateTTLSchemasWithMetadata()
-            val columnFamilyToSchemaMap = schemas.map { case (cfName, 
metadata) =>
-              cfName -> createColFamilyInfo(
-                metadata.keySchema,
-                metadata.valueSchema,
-                metadata.encoderSpec,
-                cfName,
-                metadata.useMultipleValuePerKey)
-            }
 
             val columnFamilyToSelectExprs = Map(
               TTLProcessorUtils.LIST_STATE -> 
TTLProcessorUtils.getTTLSelectExpressions(
@@ -963,9 +847,12 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
             performRoundTripTest(
               sourceDir.getAbsolutePath,
               targetDir.getAbsolutePath,
-              columnFamilyToSchemaMap,
-              columnFamilyToSelectExprs = columnFamilyToSelectExprs,
-              columnFamilyToStateSourceOptions = 
columnFamilyToStateSourceOptions,
+              storeToColumnFamilies =
+                Map(StateStoreId.DEFAULT_STORE_NAME -> schemas.keys.toList),
+              storeToColumnFamilyToSelectExprs =
+                Map(StateStoreId.DEFAULT_STORE_NAME -> 
columnFamilyToSelectExprs),
+              storeToColumnFamilyToStateSourceOptions =
+                Map(StateStoreId.DEFAULT_STORE_NAME -> 
columnFamilyToStateSourceOptions),
               operatorName = 
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
             )
           }
@@ -1006,14 +893,6 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
             )
 
             val schemas = TTLProcessorUtils.getMapStateTTLSchemasWithMetadata()
-            val columnFamilyToSchemaMap = schemas.map { case (cfName, 
metadata) =>
-              cfName -> createColFamilyInfo(
-                metadata.keySchema,
-                metadata.valueSchema,
-                metadata.encoderSpec,
-                cfName,
-                metadata.useMultipleValuePerKey)
-            }
 
             val columnFamilyToSelectExprs = Map(
               TTLProcessorUtils.MAP_STATE -> 
TTLProcessorUtils.getTTLSelectExpressions(
@@ -1027,9 +906,12 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
             performRoundTripTest(
               sourceDir.getAbsolutePath,
               targetDir.getAbsolutePath,
-              columnFamilyToSchemaMap,
-              columnFamilyToSelectExprs = columnFamilyToSelectExprs,
-              columnFamilyToStateSourceOptions = 
columnFamilyToStateSourceOptions,
+              storeToColumnFamilies =
+                Map(StateStoreId.DEFAULT_STORE_NAME -> schemas.keys.toList),
+              storeToColumnFamilyToSelectExprs =
+                Map(StateStoreId.DEFAULT_STORE_NAME -> 
columnFamilyToSelectExprs),
+              storeToColumnFamilyToStateSourceOptions =
+                Map(StateStoreId.DEFAULT_STORE_NAME -> 
columnFamilyToStateSourceOptions),
               operatorName = 
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
             )
           }
@@ -1071,14 +953,6 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
             )
 
             val schemas = 
TTLProcessorUtils.getValueStateTTLSchemasWithMetadata()
-            val columnFamilyToSchemaMap = schemas.map { case (cfName, 
metadata) =>
-              cfName -> createColFamilyInfo(
-                metadata.keySchema,
-                metadata.valueSchema,
-                metadata.encoderSpec,
-                cfName,
-                metadata.useMultipleValuePerKey)
-            }
 
             val columnFamilyToStateSourceOptions = schemas.keys.map { cfName =>
               cfName -> Map(StateSourceOptions.STATE_VAR_NAME -> cfName)
@@ -1087,8 +961,10 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
             performRoundTripTest(
               sourceDir.getAbsolutePath,
               targetDir.getAbsolutePath,
-              columnFamilyToSchemaMap,
-              columnFamilyToStateSourceOptions = 
columnFamilyToStateSourceOptions,
+              storeToColumnFamilies =
+                Map(StateStoreId.DEFAULT_STORE_NAME -> schemas.keys.toList),
+              storeToColumnFamilyToStateSourceOptions =
+                Map(StateStoreId.DEFAULT_STORE_NAME -> 
columnFamilyToStateSourceOptions),
               operatorName = 
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
             )
           }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to