This is an automated email from the ASF dual-hosted git repository.

ashrigondekar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 28f2a5cc6af0 [SPARK-52171][SS] StateDataSource join implementation for state v3
28f2a5cc6af0 is described below

commit 28f2a5cc6af078d6fd841644d4192d76146d0052
Author: Livia Zhu <livia....@databricks.com>
AuthorDate: Thu Jul 17 10:50:48 2025 -0700

    [SPARK-52171][SS] StateDataSource join implementation for state v3
    
    ### What changes were proposed in this pull request?
    
    Add an implementation of StateDataSource for state format v3, which uses virtual column families for the 4 join stores. This entails a few changes:
    
    * Inferring schema for joins needs to take in oldSchemaFilePaths for state format v3.
    * sourceOptions needs to be modified when the join store name is specified for state format v3, since the name is no longer the store name but the colFamily name. Subsequent metadata checks must also account for this.
    * A new joinColFamilyOpt needs to be passed through to StateStoreReaderInfo, StatePartitionReader, etc., so that it can be used to read the correct column family.
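
    As an illustrative sketch (not part of this change itself; the checkpoint path below is hypothetical), reading one side of the join state from a state format v3 checkpoint mirrors the new unit tests:

        // Read the left side of the stream-stream join state on a v3 checkpoint.
        val joinStateDf = spark.read
          .format("statestore")
          .option("joinSide", "left")
          .load("/tmp/checkpoint") // hypothetical checkpoint location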
    
    ### Why are the changes needed?
    
    Enable StateDataSource for join version 3.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Previously StateDataSource could not be used on checkpoints that use join state version 3, and now it can.
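
    For example (a hedged sketch; the path is hypothetical), one of the four join stores can now be queried by name on such a checkpoint, with the store name internally remapped to a virtual column family of the single default store:

        // Under state format v3 the join "stores" are virtual column families, so the
        // storeName option (e.g. left-keyToNumValues) is translated to a column family.
        val keyToNumValuesDf = spark.read
          .format("statestore")
          .option("operatorId", 0)
          .option("storeName", "left-keyToNumValues")
          .load("/tmp/checkpoint") // hypothetical checkpoint location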
    
    ### How was this patch tested?
    
    Added new unit tests and enabled previously disabled unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #51004 from liviazhu/liviazhu-db/statedatasourcereader-v3.
    
    Authored-by: Livia Zhu <livia....@databricks.com>
    Signed-off-by: Anish Shrigondekar <anish.shrigonde...@databricks.com>
---
 .../datasources/v2/state/StateDataSource.scala     | 232 ++++++++++++++-------
 .../v2/state/StatePartitionReader.scala            |  37 ++--
 .../datasources/v2/state/StateScanBuilder.scala    |  19 +-
 .../datasources/v2/state/StateTable.scala          |   6 +-
 .../v2/state/StreamStreamJoinStateHelper.scala     |  86 ++++++--
 .../StreamStreamJoinStatePartitionReader.scala     |  14 +-
 .../TransformWithStateInPySparkExec.scala          |   6 +-
 .../join/StreamingSymmetricHashJoinExec.scala      |   2 +-
 .../join/SymmetricHashJoinStateManager.scala       |   4 +-
 .../operators/stateful/statefulOperators.scala     |  13 ++
 .../TransformWithStateExec.scala                   |   2 +-
 .../sql/execution/streaming/state/RocksDB.scala    |  12 ++
 .../state/StateSchemaCompatibilityChecker.scala    |  12 +-
 .../v2/state/StateDataSourceReadSuite.scala        |  58 +++++-
 .../spark/sql/streaming/StreamingJoinSuite.scala   |   8 +-
 15 files changed, 364 insertions(+), 147 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
index 937eb1fc042d..9595e1bb71a1 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
@@ -29,14 +29,14 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.DataSourceOptions
 import org.apache.spark.sql.connector.catalog.{Table, TableProvider}
 import org.apache.spark.sql.connector.expressions.Transform
-import 
org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.{JoinSideValues,
 READ_REGISTERED_TIMERS, STATE_VAR_NAME}
+import 
org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.{JoinSideValues,
 READ_REGISTERED_TIMERS, STATE_VAR_NAME, STORE_NAME}
 import 
org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues.JoinSideValues
 import 
org.apache.spark.sql.execution.datasources.v2.state.metadata.{StateMetadataPartitionReader,
 StateMetadataTableEntry}
 import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
-import org.apache.spark.sql.execution.streaming.{OffsetSeqMetadata, 
StreamingQueryCheckpointMetadata, TimerStateUtils, 
TransformWithStateOperatorProperties, TransformWithStateVariableInfo}
+import org.apache.spark.sql.execution.streaming.{OffsetSeqMetadata, 
StatefulOperatorsUtils, StreamingQueryCheckpointMetadata, TimerStateUtils, 
TransformWithStateOperatorProperties, TransformWithStateVariableInfo}
 import 
org.apache.spark.sql.execution.streaming.StreamingCheckpointConstants.DIR_NAME_STATE
 import 
org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide,
 RightSide}
-import 
org.apache.spark.sql.execution.streaming.state.{InMemoryStateSchemaProvider, 
KeyStateEncoderSpec, NoPrefixKeyStateEncoderSpec, 
PrefixKeyScanStateEncoderSpec, StateSchemaCompatibilityChecker, 
StateSchemaMetadata, StateSchemaProvider, StateStore, 
StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProviderId}
+import 
org.apache.spark.sql.execution.streaming.state.{InMemoryStateSchemaProvider, 
KeyStateEncoderSpec, NoPrefixKeyStateEncoderSpec, 
PrefixKeyScanStateEncoderSpec, StateSchemaCompatibilityChecker, 
StateSchemaMetadata, StateSchemaProvider, StateStore, 
StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProviderId, 
SymmetricHashJoinStateManager}
 import org.apache.spark.sql.sources.DataSourceRegister
 import org.apache.spark.sql.streaming.TimeMode
 import org.apache.spark.sql.types.StructType
@@ -51,25 +51,17 @@ class StateDataSource extends TableProvider with 
DataSourceRegister with Logging
 
   private lazy val hadoopConf: Configuration = 
session.sessionState.newHadoopConf()
 
-  private lazy val serializedHadoopConf = new 
SerializableConfiguration(hadoopConf)
-
-  // Seq of operator names who uses state schema v3 and TWS related options.
-  // This Seq was used in checks before reading state schema files.
-  private val twsShortNameSeq = Seq(
-    "transformWithStateExec",
-    "transformWithStateInPandasExec",
-    "transformWithStateInPySparkExec"
-  )
-
   override def shortName(): String = "statestore"
 
   override def getTable(
       schema: StructType,
       partitioning: Array[Transform],
       properties: util.Map[String, String]): Table = {
-    val sourceOptions = StateSourceOptions.apply(session, hadoopConf, 
properties)
+    val sourceOptions = StateSourceOptions.modifySourceOptions(hadoopConf,
+      StateSourceOptions.apply(session, hadoopConf, properties))
     val stateConf = buildStateStoreConf(sourceOptions.resolvedCpLocation, 
sourceOptions.batchId)
-    val stateStoreReaderInfo: StateStoreReaderInfo = 
getStoreMetadataAndRunChecks(sourceOptions)
+    val stateStoreReaderInfo: StateStoreReaderInfo = 
getStoreMetadataAndRunChecks(
+      sourceOptions)
 
     // The key state encoder spec should be available for all operators except 
stream-stream joins
     val keyStateEncoderSpec = if 
(stateStoreReaderInfo.keyStateEncoderSpecOpt.isDefined) {
@@ -82,25 +74,28 @@ class StateDataSource extends TableProvider with 
DataSourceRegister with Logging
     new StateTable(session, schema, sourceOptions, stateConf, 
keyStateEncoderSpec,
       stateStoreReaderInfo.transformWithStateVariableInfoOpt,
       stateStoreReaderInfo.stateStoreColFamilySchemaOpt,
-      stateStoreReaderInfo.stateSchemaProviderOpt)
+      stateStoreReaderInfo.stateSchemaProviderOpt,
+      stateStoreReaderInfo.joinColFamilyOpt)
   }
 
   override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
-    val sourceOptions = StateSourceOptions.apply(session, hadoopConf, options)
+    val sourceOptions = StateSourceOptions.modifySourceOptions(hadoopConf,
+      StateSourceOptions.apply(session, hadoopConf, options))
 
-    val stateStoreReaderInfo: StateStoreReaderInfo = 
getStoreMetadataAndRunChecks(sourceOptions)
+    val stateStoreReaderInfo: StateStoreReaderInfo = 
getStoreMetadataAndRunChecks(
+      sourceOptions)
+    val oldSchemaFilePaths = 
StateDataSource.getOldSchemaFilePaths(sourceOptions, hadoopConf)
 
     val stateCheckpointLocation = sourceOptions.stateCheckpointLocation
     try {
-      // SPARK-51779 TODO: Support stream-stream joins with virtual column 
families
       val (keySchema, valueSchema) = sourceOptions.joinSide match {
         case JoinSideValues.left =>
           StreamStreamJoinStateHelper.readKeyValueSchema(session, 
stateCheckpointLocation.toString,
-            sourceOptions.operatorId, LeftSide)
+            sourceOptions.operatorId, LeftSide, oldSchemaFilePaths)
 
         case JoinSideValues.right =>
           StreamStreamJoinStateHelper.readKeyValueSchema(session, 
stateCheckpointLocation.toString,
-            sourceOptions.operatorId, RightSide)
+            sourceOptions.operatorId, RightSide, oldSchemaFilePaths)
 
         case JoinSideValues.none =>
           // we should have the schema for the state store if joinSide is none
@@ -141,19 +136,7 @@ class StateDataSource extends TableProvider with 
DataSourceRegister with Logging
   private def runStateVarChecks(
       sourceOptions: StateSourceOptions,
       stateStoreMetadata: Array[StateMetadataTableEntry]): Unit = {
-    if (sourceOptions.stateVarName.isDefined || 
sourceOptions.readRegisteredTimers) {
-      // Perform checks for transformWithState operator in case state variable 
name is provided
-      require(stateStoreMetadata.size == 1)
-      val opMetadata = stateStoreMetadata.head
-      if (!twsShortNameSeq.contains(opMetadata.operatorName)) {
-        // if we are trying to query state source with state variable name, 
then the operator
-        // should be transformWithState
-        val errorMsg = "Providing state variable names is only supported with 
the " +
-          s"transformWithState operator. Found 
operator=${opMetadata.operatorName}. " +
-          s"Please remove this option and re-run the query."
-        throw StateDataSourceErrors.invalidOptionValue(STATE_VAR_NAME, 
errorMsg)
-      }
-
+    def runTWSChecks(opMetadata: StateMetadataTableEntry): Unit = {
       // if the operator is transformWithState, but the operator properties 
are empty, then
       // the user has not defined any state variables for the operator
       val operatorProperties = opMetadata.operatorPropertiesJson
@@ -183,35 +166,74 @@ class StateDataSource extends TableProvider with 
DataSourceRegister with Logging
         throw StateDataSourceErrors.invalidOptionValue(STATE_VAR_NAME,
           s"State variable $stateVarName is not defined for the 
transformWithState operator.")
       }
-    } else {
-      // if the operator is transformWithState, then a state variable argument 
is mandatory
-      if (stateStoreMetadata.size == 1 &&
-        twsShortNameSeq.contains(stateStoreMetadata.head.operatorName)) {
-        throw StateDataSourceErrors.requiredOptionUnspecified("stateVarName")
-      }
     }
-  }
 
-  private def getStateStoreMetadata(stateSourceOptions: StateSourceOptions):
-    Array[StateMetadataTableEntry] = {
-    val allStateStoreMetadata = new StateMetadataPartitionReader(
-      stateSourceOptions.stateCheckpointLocation.getParent.toString,
-      serializedHadoopConf, stateSourceOptions.batchId).stateMetadata.toArray
-    val stateStoreMetadata = allStateStoreMetadata.filter { entry =>
-      entry.operatorId == stateSourceOptions.operatorId &&
-        entry.stateStoreName == stateSourceOptions.storeName
+    sourceOptions.stateVarName match {
+      case Some(name) =>
+        // Check that stateStoreMetadata exists
+        require(stateStoreMetadata.size == 1)
+        val opMetadata = stateStoreMetadata.head
+        opMetadata.operatorName match {
+          case opName: String if opName ==
+            StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME =>
+            // Verify that the storename is valid
+            val possibleStoreNames = 
SymmetricHashJoinStateManager.allStateStoreNames(
+              LeftSide, RightSide)
+            if (!possibleStoreNames.contains(name)) {
+              val errorMsg = s"Store name $name not allowed for join operator. 
Allowed names are " +
+                s"$possibleStoreNames. " +
+                s"Please remove this option and re-run the query."
+              throw StateDataSourceErrors.invalidOptionValue(STORE_NAME, 
errorMsg)
+            }
+          case opName: String if 
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_OP_NAMES
+            .contains(opName) =>
+            runTWSChecks(opMetadata)
+          case _ =>
+            // if we are trying to query state source with state variable 
name, then the operator
+            // should be transformWithState
+            val errorMsg = "Providing state variable names is only supported 
with the " +
+              s"transformWithState operator. Found 
operator=${opMetadata.operatorName}. " +
+              s"Please remove this option and re-run the query."
+            throw StateDataSourceErrors.invalidOptionValue(STATE_VAR_NAME, 
errorMsg)
+        }
+      case None =>
+        if (sourceOptions.readRegisteredTimers) {
+          // Check that stateStoreMetadata exists
+          require(stateStoreMetadata.size == 1)
+          val opMetadata = stateStoreMetadata.head
+          opMetadata.operatorName match {
+            case opName: String if 
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_OP_NAMES
+              .contains(opName) =>
+              runTWSChecks(opMetadata)
+            case _ =>
+              // if we are trying to query state source with state variable 
name, then the operator
+              // should be transformWithState
+              val errorMsg = "Providing readRegisteredTimers=true is only 
supported with the " +
+                s"transformWithState operator. Found 
operator=${opMetadata.operatorName}. " +
+                s"Please remove this option and re-run the query."
+              throw 
StateDataSourceErrors.invalidOptionValue(READ_REGISTERED_TIMERS, errorMsg)
+          }
+        } else {
+          // if the operator is transformWithState, then a state variable 
argument is mandatory
+          if (stateStoreMetadata.size == 1 &&
+            StatefulOperatorsUtils.TRANSFORM_WITH_STATE_OP_NAMES.contains(
+              stateStoreMetadata.head.operatorName)) {
+            throw 
StateDataSourceErrors.requiredOptionUnspecified("stateVarName")
+          }
+        }
     }
-    stateStoreMetadata
   }
 
   private def getStoreMetadataAndRunChecks(sourceOptions: StateSourceOptions):
     StateStoreReaderInfo = {
-    val storeMetadata = getStateStoreMetadata(sourceOptions)
+    val storeMetadata = StateDataSource.getStateStoreMetadata(sourceOptions, 
hadoopConf)
     runStateVarChecks(sourceOptions, storeMetadata)
+
     var keyStateEncoderSpecOpt: Option[KeyStateEncoderSpec] = None
     var stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema] = None
     var transformWithStateVariableInfoOpt: 
Option[TransformWithStateVariableInfo] = None
     var stateSchemaProvider: Option[StateSchemaProvider] = None
+    var joinColFamilyOpt: Option[String] = None
     var timeMode: String = TimeMode.None.toString
 
     if (sourceOptions.joinSide == JoinSideValues.none) {
@@ -220,34 +242,41 @@ class StateDataSource extends TableProvider with 
DataSourceRegister with Logging
 
       // Read the schema file path from operator metadata version v2 onwards
       // for the transformWithState operator
-      val oldSchemaFilePaths = if (storeMetadata.length > 0 && 
storeMetadata.head.version == 2
-        && twsShortNameSeq.exists(storeMetadata.head.operatorName.contains)) {
-        val storeMetadataEntry = storeMetadata.head
-        val operatorProperties = TransformWithStateOperatorProperties.fromJson(
-          storeMetadataEntry.operatorPropertiesJson)
-        timeMode = operatorProperties.timeMode
-
-        if (sourceOptions.readRegisteredTimers) {
-          stateVarName = TimerStateUtils.getTimerStateVarNames(timeMode)._1
+      val oldSchemaFilePaths = if (storeMetadata.length > 0 && 
storeMetadata.head.version == 2) {
+        val opName = storeMetadata.head.operatorName
+        if 
(StatefulOperatorsUtils.TRANSFORM_WITH_STATE_OP_NAMES.exists(opName.contains)) {
+          val storeMetadataEntry = storeMetadata.head
+          val operatorProperties = 
TransformWithStateOperatorProperties.fromJson(
+            storeMetadataEntry.operatorPropertiesJson)
+          timeMode = operatorProperties.timeMode
+
+          if (sourceOptions.readRegisteredTimers) {
+            stateVarName = TimerStateUtils.getTimerStateVarNames(timeMode)._1
+          }
+
+          val stateVarInfoList = operatorProperties.stateVariables
+            .filter(stateVar => stateVar.stateName == stateVarName)
+          require(stateVarInfoList.size == 1, s"Failed to find unique state 
variable info " +
+            s"for state variable $stateVarName in operator 
${sourceOptions.operatorId}")
+          val stateVarInfo = stateVarInfoList.head
+          transformWithStateVariableInfoOpt = Some(stateVarInfo)
+          val schemaFilePaths = storeMetadataEntry.stateSchemaFilePaths
+          val stateSchemaMetadata = 
StateSchemaMetadata.createStateSchemaMetadata(
+            sourceOptions.stateCheckpointLocation.toString,
+            hadoopConf,
+            schemaFilePaths
+          )
+          stateSchemaProvider = Some(new 
InMemoryStateSchemaProvider(stateSchemaMetadata))
+          schemaFilePaths.map(new Path(_))
+        } else {
+          if (opName == 
StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME) {
+            joinColFamilyOpt = Some(stateVarName)
+          }
+          StateDataSource.getOldSchemaFilePaths(sourceOptions, hadoopConf)
         }
-
-        val stateVarInfoList = operatorProperties.stateVariables
-          .filter(stateVar => stateVar.stateName == stateVarName)
-        require(stateVarInfoList.size == 1, s"Failed to find unique state 
variable info " +
-          s"for state variable $stateVarName in operator 
${sourceOptions.operatorId}")
-        val stateVarInfo = stateVarInfoList.head
-        transformWithStateVariableInfoOpt = Some(stateVarInfo)
-        val schemaFilePaths = storeMetadataEntry.stateSchemaFilePaths
-        val stateSchemaMetadata = 
StateSchemaMetadata.createStateSchemaMetadata(
-          sourceOptions.stateCheckpointLocation.toString,
-          hadoopConf,
-          schemaFilePaths
-        )
-        stateSchemaProvider = Some(new 
InMemoryStateSchemaProvider(stateSchemaMetadata))
-        schemaFilePaths.map(new Path(_))
       } else {
-        None
-      }.toList
+        StateDataSource.getOldSchemaFilePaths(sourceOptions, hadoopConf)
+      }
 
       try {
         // Read the actual state schema from the provided path for v2 or from 
the dedicated path
@@ -276,7 +305,8 @@ class StateDataSource extends TableProvider with 
DataSourceRegister with Logging
       keyStateEncoderSpecOpt,
       stateStoreColFamilySchemaOpt,
       transformWithStateVariableInfoOpt,
-      stateSchemaProvider
+      stateSchemaProvider,
+      joinColFamilyOpt
     )
   }
 
@@ -553,6 +583,27 @@ object StateSourceOptions extends DataSourceOptions {
       case None => throw 
StateDataSourceErrors.committedBatchUnavailable(checkpointLocation)
     }
   }
+
+  // Modifies options due to external data. Returns modified options.
+  // If this is a join operator specifying a store name using state format v3,
+  // we need to modify the options.
+  private[state] def modifySourceOptions(
+    hadoopConf: Configuration, sourceOptions: StateSourceOptions): 
StateSourceOptions = {
+    // If a storeName is specified (e.g. right-keyToNumValues) and v3 is used,
+    // we are using join with virtual column families not diff stores. 
Therefore,
+    // options will be modified to set stateVarName to that storeName and 
storeName
+    // to default.
+    if (sourceOptions.storeName != StateStoreId.DEFAULT_STORE_NAME &&
+      StreamStreamJoinStateHelper.usesVirtualColumnFamilies(
+        hadoopConf, sourceOptions.stateCheckpointLocation.toString,
+        sourceOptions.operatorId)) {
+      sourceOptions.copy(
+        stateVarName = Some(sourceOptions.storeName),
+        storeName = StateStoreId.DEFAULT_STORE_NAME)
+    } else {
+      sourceOptions
+    }
+  }
 }
 
 // Case class to store information around the key state encoder, col family 
schema and
@@ -561,5 +612,28 @@ case class StateStoreReaderInfo(
     keyStateEncoderSpecOpt: Option[KeyStateEncoderSpec],
     stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
     transformWithStateVariableInfoOpt: Option[TransformWithStateVariableInfo],
-    stateSchemaProviderOpt: Option[StateSchemaProvider]
+    stateSchemaProviderOpt: Option[StateSchemaProvider],
+    joinColFamilyOpt: Option[String] // Only used for join op with state 
format v3
 )
+
+object StateDataSource {
+  private def getStateStoreMetadata(
+    stateSourceOptions: StateSourceOptions,
+    hadoopConf: Configuration): Array[StateMetadataTableEntry] = {
+    val allStateStoreMetadata = new StateMetadataPartitionReader(
+      stateSourceOptions.stateCheckpointLocation.getParent.toString,
+      new SerializableConfiguration(hadoopConf), 
stateSourceOptions.batchId).stateMetadata.toArray
+    val stateStoreMetadata = allStateStoreMetadata.filter { entry =>
+      entry.operatorId == stateSourceOptions.operatorId &&
+        entry.stateStoreName == stateSourceOptions.storeName
+    }
+    stateStoreMetadata
+  }
+
+  def getOldSchemaFilePaths(
+    stateSourceOptions: StateSourceOptions,
+    hadoopConf: Configuration): List[Path] = {
+    val metadata = getStateStoreMetadata(stateSourceOptions, hadoopConf)
+    metadata.headOption.map(_.stateSchemaFilePaths.map(new 
Path(_))).getOrElse(List.empty)
+  }
+}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
index 4aa95ad42ec7..ce6dae8933eb 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
@@ -42,7 +42,8 @@ class StatePartitionReaderFactory(
     keyStateEncoderSpec: KeyStateEncoderSpec,
     stateVariableInfoOpt: Option[TransformWithStateVariableInfo],
     stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
-    stateSchemaProviderOpt: Option[StateSchemaProvider])
+    stateSchemaProviderOpt: Option[StateSchemaProvider],
+    joinColFamilyOpt: Option[String])
   extends PartitionReaderFactory {
 
   override def createReader(partition: InputPartition): 
PartitionReader[InternalRow] = {
@@ -50,11 +51,11 @@ class StatePartitionReaderFactory(
     if (stateStoreInputPartition.sourceOptions.readChangeFeed) {
       new StateStoreChangeDataPartitionReader(storeConf, hadoopConf,
         stateStoreInputPartition, schema, keyStateEncoderSpec, 
stateVariableInfoOpt,
-        stateStoreColFamilySchemaOpt, stateSchemaProviderOpt)
+        stateStoreColFamilySchemaOpt, stateSchemaProviderOpt, joinColFamilyOpt)
     } else {
       new StatePartitionReader(storeConf, hadoopConf,
         stateStoreInputPartition, schema, keyStateEncoderSpec, 
stateVariableInfoOpt,
-        stateStoreColFamilySchemaOpt, stateSchemaProviderOpt)
+        stateStoreColFamilySchemaOpt, stateSchemaProviderOpt, joinColFamilyOpt)
     }
   }
 }
@@ -71,7 +72,8 @@ abstract class StatePartitionReaderBase(
     keyStateEncoderSpec: KeyStateEncoderSpec,
     stateVariableInfoOpt: Option[TransformWithStateVariableInfo],
     stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
-    stateSchemaProviderOpt: Option[StateSchemaProvider])
+    stateSchemaProviderOpt: Option[StateSchemaProvider],
+    joinColFamilyOpt: Option[String])
   extends PartitionReader[InternalRow] with Logging {
   // Used primarily as a placeholder for the value schema in the context of
   // state variables used within the transformWithState operator.
@@ -98,11 +100,7 @@ abstract class StatePartitionReaderBase(
       partition.sourceOptions.operatorId, partition.partition, 
partition.sourceOptions.storeName)
     val stateStoreProviderId = StateStoreProviderId(stateStoreId, 
partition.queryId)
 
-    val useColFamilies = if (stateVariableInfoOpt.isDefined) {
-      true
-    } else {
-      false
-    }
+    val useColFamilies = stateVariableInfoOpt.isDefined || 
joinColFamilyOpt.isDefined
 
     val useMultipleValuesPerKey = 
SchemaUtil.checkVariableType(stateVariableInfoOpt,
       StateVariableType.ListState)
@@ -164,10 +162,11 @@ class StatePartitionReader(
     keyStateEncoderSpec: KeyStateEncoderSpec,
     stateVariableInfoOpt: Option[TransformWithStateVariableInfo],
     stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
-    stateSchemaProviderOpt: Option[StateSchemaProvider])
+    stateSchemaProviderOpt: Option[StateSchemaProvider],
+    joinColFamilyOpt: Option[String])
   extends StatePartitionReaderBase(storeConf, hadoopConf, partition, schema,
     keyStateEncoderSpec, stateVariableInfoOpt, stateStoreColFamilySchemaOpt,
-    stateSchemaProviderOpt) {
+    stateSchemaProviderOpt, joinColFamilyOpt) {
 
   private lazy val store: ReadStateStore = {
     partition.sourceOptions.fromSnapshotOptions match {
@@ -186,17 +185,18 @@ class StatePartitionReader(
   }
 
   override lazy val iter: Iterator[InternalRow] = {
-    val stateVarName = stateVariableInfoOpt
-      .map(_.stateName).getOrElse(StateStore.DEFAULT_COL_FAMILY_NAME)
+    val colFamilyName = stateStoreColFamilySchemaOpt
+      .map(_.colFamilyName).getOrElse(
+        joinColFamilyOpt.getOrElse(StateStore.DEFAULT_COL_FAMILY_NAME))
 
     if (stateVariableInfoOpt.isDefined) {
       val stateVariableInfo = stateVariableInfoOpt.get
       val stateVarType = stateVariableInfo.stateVariableType
-      SchemaUtil.processStateEntries(stateVarType, stateVarName, store,
+      SchemaUtil.processStateEntries(stateVarType, colFamilyName, store,
         keySchema, partition.partition, partition.sourceOptions)
     } else {
       store
-        .iterator(stateVarName)
+        .iterator(colFamilyName)
         .map { pair =>
           SchemaUtil.unifyStateRowPair((pair.key, pair.value), 
partition.partition)
         }
@@ -221,10 +221,11 @@ class StateStoreChangeDataPartitionReader(
     keyStateEncoderSpec: KeyStateEncoderSpec,
     stateVariableInfoOpt: Option[TransformWithStateVariableInfo],
     stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
-    stateSchemaProviderOpt: Option[StateSchemaProvider])
+    stateSchemaProviderOpt: Option[StateSchemaProvider],
+    joinColFamilyOpt: Option[String])
   extends StatePartitionReaderBase(storeConf, hadoopConf, partition, schema,
     keyStateEncoderSpec, stateVariableInfoOpt, stateStoreColFamilySchemaOpt,
-    stateSchemaProviderOpt) {
+    stateSchemaProviderOpt, joinColFamilyOpt) {
 
   private lazy val changeDataReader:
     NextIterator[(RecordType.Value, UnsafeRow, UnsafeRow, Long)] = {
@@ -235,6 +236,8 @@ class StateStoreChangeDataPartitionReader(
 
     val colFamilyNameOpt = if (stateVariableInfoOpt.isDefined) {
       Some(stateVariableInfoOpt.get.stateName)
+    } else if (joinColFamilyOpt.isDefined) {
+      Some(joinColFamilyOpt.get)
     } else {
       None
     }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala
index 3b8dad7a1809..9adabc7096bb 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala
@@ -45,9 +45,11 @@ class StateScanBuilder(
     keyStateEncoderSpec: KeyStateEncoderSpec,
     stateVariableInfoOpt: Option[TransformWithStateVariableInfo],
     stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
-    stateSchemaProviderOpt: Option[StateSchemaProvider]) extends ScanBuilder {
+    stateSchemaProviderOpt: Option[StateSchemaProvider],
+    joinColFamilyOpt: Option[String]) extends ScanBuilder {
   override def build(): Scan = new StateScan(session, schema, sourceOptions, 
stateStoreConf,
-    keyStateEncoderSpec, stateVariableInfoOpt, stateStoreColFamilySchemaOpt, 
stateSchemaProviderOpt)
+    keyStateEncoderSpec, stateVariableInfoOpt, stateStoreColFamilySchemaOpt, 
stateSchemaProviderOpt,
+    joinColFamilyOpt)
 }
 
 /** An implementation of [[InputPartition]] for State Store data source. */
@@ -65,7 +67,8 @@ class StateScan(
     keyStateEncoderSpec: KeyStateEncoderSpec,
     stateVariableInfoOpt: Option[TransformWithStateVariableInfo],
     stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
-    stateSchemaProviderOpt: Option[StateSchemaProvider])
+    stateSchemaProviderOpt: Option[StateSchemaProvider],
+    joinColFamilyOpt: Option[String])
   extends Scan with Batch {
 
   // A Hadoop Configuration can be about 10 KB, which is pretty big, so 
broadcast it
@@ -120,24 +123,28 @@ class StateScan(
   override def createReaderFactory(): PartitionReaderFactory = 
sourceOptions.joinSide match {
     case JoinSideValues.left =>
       val userFacingSchema = schema
+      val oldSchemaFilePaths = 
StateDataSource.getOldSchemaFilePaths(sourceOptions,
+        hadoopConfBroadcast.value.value)
       val stateSchema = StreamStreamJoinStateHelper.readSchema(session,
         sourceOptions.stateCheckpointLocation.toString, 
sourceOptions.operatorId, LeftSide,
-        excludeAuxColumns = false)
+        oldSchemaFilePaths, excludeAuxColumns = false)
       new StreamStreamJoinStatePartitionReaderFactory(stateStoreConf,
         hadoopConfBroadcast.value, userFacingSchema, stateSchema)
 
     case JoinSideValues.right =>
       val userFacingSchema = schema
+      val oldSchemaFilePaths = 
StateDataSource.getOldSchemaFilePaths(sourceOptions,
+        hadoopConfBroadcast.value.value)
       val stateSchema = StreamStreamJoinStateHelper.readSchema(session,
         sourceOptions.stateCheckpointLocation.toString, 
sourceOptions.operatorId, RightSide,
-        excludeAuxColumns = false)
+        oldSchemaFilePaths, excludeAuxColumns = false)
       new StreamStreamJoinStatePartitionReaderFactory(stateStoreConf,
         hadoopConfBroadcast.value, userFacingSchema, stateSchema)
 
     case JoinSideValues.none =>
       new StatePartitionReaderFactory(stateStoreConf, 
hadoopConfBroadcast.value, schema,
         keyStateEncoderSpec, stateVariableInfoOpt, 
stateStoreColFamilySchemaOpt,
-        stateSchemaProviderOpt)
+        stateSchemaProviderOpt, joinColFamilyOpt)
   }
 
   override def toBatch: Batch = this
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala
index 71b18be7fdf5..96614f0613c9 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala
@@ -44,7 +44,8 @@ class StateTable(
     keyStateEncoderSpec: KeyStateEncoderSpec,
     stateVariableInfoOpt: Option[TransformWithStateVariableInfo],
     stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
-    stateSchemaProviderOpt: Option[StateSchemaProvider])
+    stateSchemaProviderOpt: Option[StateSchemaProvider],
+    joinColFamilyOpt: Option[String])
   extends Table with SupportsRead with SupportsMetadataColumns {
 
   import StateTable._
@@ -85,7 +86,8 @@ class StateTable(
 
   override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder =
     new StateScanBuilder(session, schema, sourceOptions, stateConf, 
keyStateEncoderSpec,
-      stateVariableInfoOpt, stateStoreColFamilySchemaOpt, 
stateSchemaProviderOpt)
+      stateVariableInfoOpt, stateStoreColFamilySchemaOpt, 
stateSchemaProviderOpt,
+      joinColFamilyOpt)
 
   override def properties(): util.Map[String, String] = Map.empty[String, 
String].asJava
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStateHelper.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStateHelper.scala
index 1a04d24f0048..3abd1924f543 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStateHelper.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStateHelper.scala
@@ -18,8 +18,12 @@ package org.apache.spark.sql.execution.datasources.v2.state
 
 import java.util.UUID
 
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+
 import org.apache.spark.sql.SparkSession
-import 
org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.JoinSide
+import org.apache.spark.sql.execution.streaming.CheckpointFileManager
+import 
org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{JoinSide,
 LeftSide}
 import 
org.apache.spark.sql.execution.streaming.state.{StateSchemaCompatibilityChecker,
 StateStore, StateStoreId, StateStoreProviderId, SymmetricHashJoinStateManager}
 import org.apache.spark.sql.types.{BooleanType, StructType}
 
@@ -35,52 +39,92 @@ object StreamStreamJoinStateHelper {
       stateCheckpointLocation: String,
       operatorId: Int,
       side: JoinSide,
+      oldSchemaFilePaths: List[Path],
       excludeAuxColumns: Boolean = true): StructType = {
     val (keySchema, valueSchema) = readKeyValueSchema(session, 
stateCheckpointLocation,
-      operatorId, side, excludeAuxColumns)
+      operatorId, side, oldSchemaFilePaths, excludeAuxColumns)
 
     new StructType()
       .add("key", keySchema)
       .add("value", valueSchema)
   }
 
+  // Returns whether the checkpoint uses stateFormatVersion 3 which uses VCF 
for the join.
+  def usesVirtualColumnFamilies(
+    hadoopConf: Configuration,
+    stateCheckpointLocation: String,
+    operatorId: Int): Boolean = {
+    // If the schema exists for operatorId/partitionId/left-keyToNumValues, it 
is not
+    // stateFormatVersion 3.
+    val partitionId = StateStore.PARTITION_ID_TO_CHECK_SCHEMA
+    val storeId = new StateStoreId(stateCheckpointLocation, operatorId,
+      partitionId, 
SymmetricHashJoinStateManager.allStateStoreNames(LeftSide).toList.head)
+    val schemaFilePath = StateSchemaCompatibilityChecker.schemaFile(
+      storeId.storeCheckpointLocation())
+    val fm = CheckpointFileManager.create(schemaFilePath, hadoopConf)
+    !fm.exists(schemaFilePath)
+  }
+
   def readKeyValueSchema(
       session: SparkSession,
       stateCheckpointLocation: String,
       operatorId: Int,
       side: JoinSide,
+      oldSchemaFilePaths: List[Path],
       excludeAuxColumns: Boolean = true): (StructType, StructType) = {
 
+    val newHadoopConf = session.sessionState.newHadoopConf()
+    val partitionId = StateStore.PARTITION_ID_TO_CHECK_SCHEMA
     // KeyToNumValuesType, KeyWithIndexToValueType
     val storeNames = 
SymmetricHashJoinStateManager.allStateStoreNames(side).toList
 
-    val partitionId = StateStore.PARTITION_ID_TO_CHECK_SCHEMA
-    val storeIdForKeyToNumValues = new StateStoreId(stateCheckpointLocation, 
operatorId,
-      partitionId, storeNames(0))
-    val providerIdForKeyToNumValues = new 
StateStoreProviderId(storeIdForKeyToNumValues,
-      UUID.randomUUID())
+    val (keySchema, valueSchema) =
+      if (!usesVirtualColumnFamilies(
+        newHadoopConf, stateCheckpointLocation, operatorId)) {
+        val storeIdForKeyToNumValues = new 
StateStoreId(stateCheckpointLocation, operatorId,
+          partitionId, storeNames(0))
+        val providerIdForKeyToNumValues = new 
StateStoreProviderId(storeIdForKeyToNumValues,
+          UUID.randomUUID())
 
-    val storeIdForKeyWithIndexToValue = new 
StateStoreId(stateCheckpointLocation,
-      operatorId, partitionId, storeNames(1))
-    val providerIdForKeyWithIndexToValue = new 
StateStoreProviderId(storeIdForKeyWithIndexToValue,
-      UUID.randomUUID())
+        val storeIdForKeyWithIndexToValue = new 
StateStoreId(stateCheckpointLocation,
+          operatorId, partitionId, storeNames(1))
+        val providerIdForKeyWithIndexToValue = new StateStoreProviderId(
+          storeIdForKeyWithIndexToValue, UUID.randomUUID())
 
-    val newHadoopConf = session.sessionState.newHadoopConf()
+        // read the key schema from the keyToNumValues store for the join keys
+        val manager = new StateSchemaCompatibilityChecker(
+          providerIdForKeyToNumValues, newHadoopConf, oldSchemaFilePaths)
+        val kSchema = manager.readSchemaFile().head.keySchema
+
+        // read the value schema from the keyWithIndexToValue store for the 
values
+        val manager2 = new 
StateSchemaCompatibilityChecker(providerIdForKeyWithIndexToValue,
+          newHadoopConf, oldSchemaFilePaths)
+        val vSchema = manager2.readSchemaFile().head.valueSchema
+
+        (kSchema, vSchema)
+      } else {
+        val storeId = new StateStoreId(stateCheckpointLocation, operatorId,
+          partitionId, StateStoreId.DEFAULT_STORE_NAME)
+        val providerId = new StateStoreProviderId(storeId, UUID.randomUUID())
+
+        val manager = new StateSchemaCompatibilityChecker(
+          providerId, newHadoopConf, oldSchemaFilePaths)
+        val kSchema = manager.readSchemaFile().find { schema =>
+          schema.colFamilyName == storeNames(0)
+        }.map(_.keySchema).get
 
-    // read the key schema from the keyToNumValues store for the join keys
-    val manager = new 
StateSchemaCompatibilityChecker(providerIdForKeyToNumValues, newHadoopConf)
-    val keySchema = manager.readSchemaFile().head.keySchema
+        val vSchema = manager.readSchemaFile().find { schema =>
+          schema.colFamilyName == storeNames(1)
+        }.map(_.valueSchema).get
 
-    // read the value schema from the keyWithIndexToValue store for the values
-    val manager2 = new 
StateSchemaCompatibilityChecker(providerIdForKeyWithIndexToValue,
-      newHadoopConf)
-    val valueSchema = manager2.readSchemaFile().head.valueSchema
+        (kSchema, vSchema)
+      }
 
     val maybeMatchedColumn = valueSchema.last
 
     if (excludeAuxColumns
-        && maybeMatchedColumn.name == "matched"
-        && maybeMatchedColumn.dataType == BooleanType) {
+      && maybeMatchedColumn.name == "matched"
+      && maybeMatchedColumn.dataType == BooleanType) {
       // remove internal column `matched` for format version 2
       (keySchema, StructType(valueSchema.dropRight(1)))
     } else {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala
index e1d61de77380..82415b9f30c5 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala
@@ -80,8 +80,18 @@ class StreamStreamJoinStatePartitionReader(
   private val (inputAttributes, formatVersion) = {
     val maybeMatchedColumn = valueSchema.last
     val (fields, version) = {
+      // If there is a matched column, version is either 2 or 3. We need to 
drop the matched
+      // column from the value schema to get the actual fields.
       if (maybeMatchedColumn.name == "matched" && maybeMatchedColumn.dataType 
== BooleanType) {
-        (valueSchema.dropRight(1), 2)
+        // If checkpoint is using one store and virtual column families, 
version is 3
+        if (StreamStreamJoinStateHelper.usesVirtualColumnFamilies(
+          hadoopConf.value,
+          partition.sourceOptions.stateCheckpointLocation.toString,
+          partition.sourceOptions.operatorId)) {
+          (valueSchema.dropRight(1), 3)
+        } else {
+          (valueSchema.dropRight(1), 2)
+        }
       } else {
         (valueSchema, 1)
       }
@@ -137,7 +147,7 @@ class StreamStreamJoinStatePartitionReader(
       inputAttributes)
 
     joinStateManager.iterator.map { pair =>
-      if (formatVersion == 2) {
+      if (formatVersion >= 2) {
         val row = valueWithMatchedRowGenerator(pair.value)
         row.setBoolean(indexOrdinalInValueWithMatchedRow, pair.matched)
         unifyStateRowPair(pair.key, row)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala
index b65d46fb1632..4cb924313248 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.{CoGroupedIterator, 
SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.python.ArrowPythonRunner
 import org.apache.spark.sql.execution.python.PandasGroupUtils.{executePython, 
groupAndProject, resolveArgOffsets}
-import 
org.apache.spark.sql.execution.streaming.{DriverStatefulProcessorHandleImpl, 
StatefulOperatorStateInfo, StatefulProcessorHandleImpl, 
TransformWithStateExecBase, TransformWithStateVariableInfo}
+import 
org.apache.spark.sql.execution.streaming.{DriverStatefulProcessorHandleImpl, 
StatefulOperatorStateInfo, StatefulOperatorsUtils, StatefulProcessorHandleImpl, 
TransformWithStateExecBase, TransformWithStateVariableInfo}
 import 
org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.StateStoreAwareZipPartitionsHelper
 import 
org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, 
RocksDBStateStoreProvider, StateSchemaValidationResult, StateStore, 
StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreOps, 
StateStoreProvider, StateStoreProviderId}
 import org.apache.spark.sql.internal.SQLConf
@@ -95,9 +95,9 @@ case class TransformWithStateInPySparkExec(
   override def shortName: String = if (
     userFacingDataType == TransformWithStateInPySpark.UserFacingDataType.PANDAS
   ) {
-    "transformWithStateInPandasExec"
+    StatefulOperatorsUtils.TRANSFORM_WITH_STATE_IN_PANDAS_EXEC_OP_NAME
   } else {
-    "transformWithStateInPySparkExec"
+    StatefulOperatorsUtils.TRANSFORM_WITH_STATE_IN_PYSPARK_EXEC_OP_NAME
   }
 
   private val pythonUDF = functionExpr.asInstanceOf[PythonUDF]
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala
index 7d71db8d8e4b..839d610550ab 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala
@@ -236,7 +236,7 @@ case class StreamingSymmetricHashJoinExec(
     case _ => throwBadJoinTypeException()
   }
 
-  override def shortName: String = "symmetricHashJoin"
+  override def shortName: String = 
StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME
 
   override val stateStoreNames: Seq[String] = _stateStoreNames
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala
index 6ec197d7cc7b..619671d99e57 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala
@@ -737,9 +737,9 @@ abstract class SymmetricHashJoinStateManager(
     if (useVirtualColumnFamilies) {
       stateStore.createColFamilyIfAbsent(
         colFamilyName,
-        keySchema,
+        keyWithIndexSchema,
         valueRowConverter.valueAttributes.toStructType,
-        NoPrefixKeyStateEncoderSpec(keySchema)
+        NoPrefixKeyStateEncoderSpec(keyWithIndexSchema)
       )
     }
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala
index d92e5dbae1aa..027b911262ea 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala
@@ -1546,3 +1546,16 @@ trait SchemaValidationUtils extends Logging {
         schemaEvolutionEnabled = usingAvro && 
schemaEvolutionEnabledForOperator))
   }
 }
+
+object StatefulOperatorsUtils {
+  val TRANSFORM_WITH_STATE_EXEC_OP_NAME = "transformWithStateExec"
+  val TRANSFORM_WITH_STATE_IN_PANDAS_EXEC_OP_NAME = 
"transformWithStateInPandasExec"
+  val TRANSFORM_WITH_STATE_IN_PYSPARK_EXEC_OP_NAME = 
"transformWithStateInPySparkExec"
+  // Seq of operator names who uses state schema v3 and TWS related options.
+  val TRANSFORM_WITH_STATE_OP_NAMES: Seq[String] = Seq(
+    TRANSFORM_WITH_STATE_EXEC_OP_NAME,
+    TRANSFORM_WITH_STATE_IN_PANDAS_EXEC_OP_NAME,
+    TRANSFORM_WITH_STATE_IN_PYSPARK_EXEC_OP_NAME
+  )
+  val SYMMETRIC_HASH_JOIN_EXEC_OP_NAME = "symmetricHashJoin"
+}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala
index 80fdaa1e71e2..db3b7841ac8e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala
@@ -84,7 +84,7 @@ case class TransformWithStateExec(
     initialState)
   with ObjectProducerExec {
 
-  override def shortName: String = "transformWithStateExec"
+  override def shortName: String = 
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
 
   // We need to just initialize key and value deserializer once per partition.
   // The deserializers need to be lazily created on the executor since they
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
index 6b3bec207703..57a8265524b9 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
@@ -644,6 +644,18 @@ class RocksDB(
       workingDir, rocksDBFileMapping)
     loadedVersion = snapshotVersion
     lastSnapshotVersion = snapshotVersion
+
+    setInitialCFInfo()
+    metadata.columnFamilyMapping.foreach { mapping =>
+      mapping.foreach { case (colFamilyName, cfId) =>
+        addToColFamilyMaps(colFamilyName, cfId, 
isInternalColFamily(colFamilyName, metadata))
+      }
+    }
+
+    metadata.maxColumnFamilyId.foreach { maxId =>
+      maxColumnFamilyId.set(maxId)
+    }
+
     openDB()
 
     val (numKeys, numInternalKeys) = if (!conf.trackTotalNumberOfRows) {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala
index 17a36e5210b9..08cfd8fb197a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.util.UnsafeRowUtils
 import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, 
StatefulOperatorStateInfo}
 import 
org.apache.spark.sql.execution.streaming.state.SchemaHelper.{SchemaReader, 
SchemaWriter}
 import 
org.apache.spark.sql.execution.streaming.state.StateSchemaCompatibilityChecker.SCHEMA_FORMAT_V3
-import org.apache.spark.sql.internal.SessionState
+import org.apache.spark.sql.internal.{SessionState, SQLConf}
 import org.apache.spark.sql.types._
 
 // Result returned after validating the schema of the state store for schema 
changes
@@ -88,7 +88,7 @@ class StateSchemaCompatibilityChecker(
   // per query. This variable is the latest one
   private val schemaFileLocation = if (oldSchemaFilePaths.isEmpty) {
     val storeCpLocation = providerId.storeId.storeCheckpointLocation()
-    schemaFile(storeCpLocation)
+    StateSchemaCompatibilityChecker.schemaFile(storeCpLocation)
   } else {
     oldSchemaFilePaths.last
   }
@@ -97,7 +97,7 @@ class StateSchemaCompatibilityChecker(
 
   fm.mkdirs(schemaFileLocation.getParent)
 
-  private val conf = SparkSession.getActiveSession.get.sessionState.conf
+  private val conf = 
SparkSession.getActiveSession.map(_.sessionState.conf).getOrElse(new SQLConf())
 
   // Read most recent schema file
   def readSchemaFile(): List[StateStoreColFamilySchema] = {
@@ -302,9 +302,6 @@ class StateSchemaCompatibilityChecker(
       newSchemaFileWritten
     }
   }
-
-  private def schemaFile(storeCpLocation: Path): Path =
-    new Path(new Path(storeCpLocation, "_metadata"), "schema")
 }
 
 object StateSchemaCompatibilityChecker extends Logging {
@@ -432,4 +429,7 @@ object StateSchemaCompatibilityChecker extends Logging {
 
     StateSchemaValidationResult(evolvedSchema, schemaFileLocation)
   }
+
+  def schemaFile(storeCpLocation: Path): Path =
+    new Path(new Path(storeCpLocation, "_metadata"), "schema")
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
index 56a6a1e641f4..dbcf0afb1273 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
@@ -526,6 +526,9 @@ class RocksDBStateDataSourceReadSuite extends 
StateDataSourceReadSuite {
 
 class RocksDBWithChangelogCheckpointStateDataSourceReaderSuite extends
 StateDataSourceReadSuite {
+
+  import testImplicits._
+
   override protected def newStateStoreProvider(): RocksDBStateStoreProvider =
     new RocksDBStateStoreProvider
 
@@ -568,6 +571,49 @@ StateDataSourceReadSuite {
     testSnapshotOnJoinState("rocksdb", 1)
     testSnapshotOnJoinState("rocksdb", 2)
   }
+
+  /**
+   * Note that we cannot use the golden files approach for transformWithState. 
The new schema
+   * format keeps track of the schema file path as an absolute path which 
cannot be used with
+   * the getResource model used in other similar tests on runbot.
+   */
+  test("snapshotStartBatchId on join state v3") {
+    withTempDir { tmpDir =>
+      withSQLConf(
+        SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION.key -> "3",
+        SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100"
+      ) {
+        val inputData = MemoryStream[(Int, Long)]
+        val query = getStreamStreamJoinQuery(inputData)
+        testStream(query)(
+          StartStream(checkpointLocation = tmpDir.getCanonicalPath),
+          AddData(inputData, (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L)),
+          ProcessAllAvailable(),
+          Execute { _ => Thread.sleep(2000) },
+          AddData(inputData, (6, 6L), (7, 7L), (8, 8L), (9, 9L), (10, 10L)),
+          ProcessAllAvailable(),
+          Execute { _ => Thread.sleep(2000) },
+          AddData(inputData, (11, 11L), (12, 12L), (13, 13L), (14, 14L), (15, 
15L)),
+          ProcessAllAvailable(),
+          Execute { _ => Thread.sleep(5000) },
+          StopStream
+        )
+
+        val stateSnapshotDf = spark.read.format("statestore")
+          .option("snapshotPartitionId", 2)
+          .option("snapshotStartBatchId", 0)
+          .option("joinSide", "left")
+          .load(tmpDir.getCanonicalPath)
+
+        val stateDf = spark.read.format("statestore")
+          .option("joinSide", "left")
+          .load(tmpDir.getCanonicalPath)
+          .filter(col("partition_id") === 2)
+
+        checkAnswer(stateSnapshotDf, stateDf)
+      }
+    }
+  }
 }
 
 abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with 
Assertions {
@@ -869,6 +915,10 @@ abstract class StateDataSourceReadSuite extends 
StateDataSourceTestBase with Ass
     testStreamStreamJoin(2)
   }
 
+  test("stream-stream join, state ver 3") {
+    testStreamStreamJoin(3)
+  }
+
   private def testStreamStreamJoin(stateVersion: Int): Unit = {
     def assertInternalColumnIsNotExposed(df: DataFrame): Unit = {
       val valueSchema = SchemaUtil.getSchemaAsDataType(df.schema, "value")
@@ -879,6 +929,12 @@ abstract class StateDataSourceReadSuite extends 
StateDataSourceTestBase with Ass
       }
     }
 
+    // We should only test state version 3 with RocksDBStateStoreProvider
+    if (stateVersion == 3
+      && SQLConf.get.stateStoreProviderClass != 
classOf[RocksDBStateStoreProvider].getName) {
+      return
+    }
+
     withSQLConf(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION.key -> 
stateVersion.toString) {
       withTempDir { tempDir =>
         runStreamStreamJoinQuery(tempDir.getAbsolutePath)
@@ -939,7 +995,7 @@ abstract class StateDataSourceReadSuite extends 
StateDataSourceTestBase with Ass
 
         val stateReadDfForRightKeyWithIndexToValue = 
stateReaderForRightKeyWithIndexToValue.load()
 
-        if (stateVersion == 2) {
+        if (stateVersion >= 2) {
           val resultDf4 = stateReadDfForRightKeyWithIndexToValue
             .selectExpr("key.field0 AS key_0", "key.index AS key_index",
               "value.rightId AS rightId", "CAST(value.rightTime AS integer) AS 
rightTime",
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
index 43e064d86117..5c75f82ebef1 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
@@ -696,9 +696,7 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite {
       })
   }
 
-  // This does not need to be run with virtual column family joins as it 
restores the state store
-  // provider to HDFS and join version to 1, effectively disabling the virtual 
column family join.
-  testWithoutVirtualColumnFamilyJoins(
+  test(
     "SPARK-26187 restore the stream-stream inner join query from Spark 2.4") {
     val inputStream = MemoryStream[(Int, Long)]
     val df = inputStream.toDS()
@@ -1524,9 +1522,7 @@ class StreamingOuterJoinSuite extends StreamingJoinSuite {
     )
   }
 
-  // This does not need to be run with virtual column family joins as it 
restores the state store
-  // provider to HDFS and join version to 1, effectively disabling the virtual 
column family join.
-  testWithoutVirtualColumnFamilyJoins(
+  test(
     "SPARK-26187 restore the stream-stream outer join query from Spark 2.4") {
     val inputStream = MemoryStream[(Int, Long)]
     val df = inputStream.toDS()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
