This is an automated email from the ASF dual-hosted git repository.

ashrigondekar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 28f2a5cc6af0 [SPARK-52171][SS] StateDataSource join implementation for state v3
28f2a5cc6af0 is described below

commit 28f2a5cc6af078d6fd841644d4192d76146d0052
Author: Livia Zhu <livia....@databricks.com>
AuthorDate: Thu Jul 17 10:50:48 2025 -0700

[SPARK-52171][SS] StateDataSource join implementation for state v3

### What changes were proposed in this pull request?

Add an implementation of StateDataSource for state format v3, which uses virtual column families for the 4 join stores. This entails a few changes:
* Inferring the schema for joins needs to take in oldSchemaFilePaths for state format v3.
* sourceOptions need to be modified when the join store name is specified for state format v3, since the name is no longer the store name but the colFamily name. Subsequent metadata checks must also account for this.
* A new joinColFamilyOpt needs to be passed through to StateStoreReaderInfo, StatePartitionReader, etc. so that it can be used to read the correct column family.

### Why are the changes needed?

Enable StateDataSource for join state format version 3.

### Does this PR introduce _any_ user-facing change?

Yes. Previously StateDataSource could not be used on checkpoints that use join state version 3; now it can.

### How was this patch tested?

New unit tests, plus previously disabled unit tests that are now enabled.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #51004 from liviazhu/liviazhu-db/statedatasourcereader-v3.

Authored-by: Livia Zhu <livia....@databricks.com>
Signed-off-by: Anish Shrigondekar <anish.shrigonde...@databricks.com>
---
 .../datasources/v2/state/StateDataSource.scala | 232 ++++++++++++++-------
 .../v2/state/StatePartitionReader.scala | 37 ++--
 .../datasources/v2/state/StateScanBuilder.scala | 19 +-
 .../datasources/v2/state/StateTable.scala | 6 +-
 .../v2/state/StreamStreamJoinStateHelper.scala | 86 ++++++--
 .../StreamStreamJoinStatePartitionReader.scala | 14 +-
 .../TransformWithStateInPySparkExec.scala | 6 +-
 .../join/StreamingSymmetricHashJoinExec.scala | 2 +-
 .../join/SymmetricHashJoinStateManager.scala | 4 +-
 .../operators/stateful/statefulOperators.scala | 13 ++
 .../TransformWithStateExec.scala | 2 +-
 .../sql/execution/streaming/state/RocksDB.scala | 12 ++
 .../state/StateSchemaCompatibilityChecker.scala | 12 +-
 .../v2/state/StateDataSourceReadSuite.scala | 58 +++++-
 .../spark/sql/streaming/StreamingJoinSuite.scala | 8 +-
 15 files changed, 364 insertions(+), 147 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index 937eb1fc042d..9595e1bb71a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -29,14 +29,14 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.DataSourceOptions import org.apache.spark.sql.connector.catalog.{Table, TableProvider} import org.apache.spark.sql.connector.expressions.Transform -import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.{JoinSideValues, READ_REGISTERED_TIMERS, STATE_VAR_NAME} +import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.{JoinSideValues, READ_REGISTERED_TIMERS, STATE_VAR_NAME, STORE_NAME} import
org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues.JoinSideValues import org.apache.spark.sql.execution.datasources.v2.state.metadata.{StateMetadataPartitionReader, StateMetadataTableEntry} import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil -import org.apache.spark.sql.execution.streaming.{OffsetSeqMetadata, StreamingQueryCheckpointMetadata, TimerStateUtils, TransformWithStateOperatorProperties, TransformWithStateVariableInfo} +import org.apache.spark.sql.execution.streaming.{OffsetSeqMetadata, StatefulOperatorsUtils, StreamingQueryCheckpointMetadata, TimerStateUtils, TransformWithStateOperatorProperties, TransformWithStateVariableInfo} import org.apache.spark.sql.execution.streaming.StreamingCheckpointConstants.DIR_NAME_STATE import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide} -import org.apache.spark.sql.execution.streaming.state.{InMemoryStateSchemaProvider, KeyStateEncoderSpec, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, StateSchemaCompatibilityChecker, StateSchemaMetadata, StateSchemaProvider, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProviderId} +import org.apache.spark.sql.execution.streaming.state.{InMemoryStateSchemaProvider, KeyStateEncoderSpec, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, StateSchemaCompatibilityChecker, StateSchemaMetadata, StateSchemaProvider, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProviderId, SymmetricHashJoinStateManager} import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.streaming.TimeMode import org.apache.spark.sql.types.StructType @@ -51,25 +51,17 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging private lazy val hadoopConf: Configuration = session.sessionState.newHadoopConf() - private lazy val serializedHadoopConf = new SerializableConfiguration(hadoopConf) - - // Seq of operator names who uses state schema v3 and TWS related options. - // This Seq was used in checks before reading state schema files. 
- private val twsShortNameSeq = Seq( - "transformWithStateExec", - "transformWithStateInPandasExec", - "transformWithStateInPySparkExec" - ) - override def shortName(): String = "statestore" override def getTable( schema: StructType, partitioning: Array[Transform], properties: util.Map[String, String]): Table = { - val sourceOptions = StateSourceOptions.apply(session, hadoopConf, properties) + val sourceOptions = StateSourceOptions.modifySourceOptions(hadoopConf, + StateSourceOptions.apply(session, hadoopConf, properties)) val stateConf = buildStateStoreConf(sourceOptions.resolvedCpLocation, sourceOptions.batchId) - val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks(sourceOptions) + val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks( + sourceOptions) // The key state encoder spec should be available for all operators except stream-stream joins val keyStateEncoderSpec = if (stateStoreReaderInfo.keyStateEncoderSpecOpt.isDefined) { @@ -82,25 +74,28 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging new StateTable(session, schema, sourceOptions, stateConf, keyStateEncoderSpec, stateStoreReaderInfo.transformWithStateVariableInfoOpt, stateStoreReaderInfo.stateStoreColFamilySchemaOpt, - stateStoreReaderInfo.stateSchemaProviderOpt) + stateStoreReaderInfo.stateSchemaProviderOpt, + stateStoreReaderInfo.joinColFamilyOpt) } override def inferSchema(options: CaseInsensitiveStringMap): StructType = { - val sourceOptions = StateSourceOptions.apply(session, hadoopConf, options) + val sourceOptions = StateSourceOptions.modifySourceOptions(hadoopConf, + StateSourceOptions.apply(session, hadoopConf, options)) - val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks(sourceOptions) + val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks( + sourceOptions) + val oldSchemaFilePaths = StateDataSource.getOldSchemaFilePaths(sourceOptions, hadoopConf) val stateCheckpointLocation = sourceOptions.stateCheckpointLocation try { - // SPARK-51779 TODO: Support stream-stream joins with virtual column families val (keySchema, valueSchema) = sourceOptions.joinSide match { case JoinSideValues.left => StreamStreamJoinStateHelper.readKeyValueSchema(session, stateCheckpointLocation.toString, - sourceOptions.operatorId, LeftSide) + sourceOptions.operatorId, LeftSide, oldSchemaFilePaths) case JoinSideValues.right => StreamStreamJoinStateHelper.readKeyValueSchema(session, stateCheckpointLocation.toString, - sourceOptions.operatorId, RightSide) + sourceOptions.operatorId, RightSide, oldSchemaFilePaths) case JoinSideValues.none => // we should have the schema for the state store if joinSide is none @@ -141,19 +136,7 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging private def runStateVarChecks( sourceOptions: StateSourceOptions, stateStoreMetadata: Array[StateMetadataTableEntry]): Unit = { - if (sourceOptions.stateVarName.isDefined || sourceOptions.readRegisteredTimers) { - // Perform checks for transformWithState operator in case state variable name is provided - require(stateStoreMetadata.size == 1) - val opMetadata = stateStoreMetadata.head - if (!twsShortNameSeq.contains(opMetadata.operatorName)) { - // if we are trying to query state source with state variable name, then the operator - // should be transformWithState - val errorMsg = "Providing state variable names is only supported with the " + - s"transformWithState operator. 
Found operator=${opMetadata.operatorName}. " + - s"Please remove this option and re-run the query." - throw StateDataSourceErrors.invalidOptionValue(STATE_VAR_NAME, errorMsg) - } - + def runTWSChecks(opMetadata: StateMetadataTableEntry): Unit = { // if the operator is transformWithState, but the operator properties are empty, then // the user has not defined any state variables for the operator val operatorProperties = opMetadata.operatorPropertiesJson @@ -183,35 +166,74 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging throw StateDataSourceErrors.invalidOptionValue(STATE_VAR_NAME, s"State variable $stateVarName is not defined for the transformWithState operator.") } - } else { - // if the operator is transformWithState, then a state variable argument is mandatory - if (stateStoreMetadata.size == 1 && - twsShortNameSeq.contains(stateStoreMetadata.head.operatorName)) { - throw StateDataSourceErrors.requiredOptionUnspecified("stateVarName") - } } - } - private def getStateStoreMetadata(stateSourceOptions: StateSourceOptions): - Array[StateMetadataTableEntry] = { - val allStateStoreMetadata = new StateMetadataPartitionReader( - stateSourceOptions.stateCheckpointLocation.getParent.toString, - serializedHadoopConf, stateSourceOptions.batchId).stateMetadata.toArray - val stateStoreMetadata = allStateStoreMetadata.filter { entry => - entry.operatorId == stateSourceOptions.operatorId && - entry.stateStoreName == stateSourceOptions.storeName + sourceOptions.stateVarName match { + case Some(name) => + // Check that stateStoreMetadata exists + require(stateStoreMetadata.size == 1) + val opMetadata = stateStoreMetadata.head + opMetadata.operatorName match { + case opName: String if opName == + StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME => + // Verify that the storename is valid + val possibleStoreNames = SymmetricHashJoinStateManager.allStateStoreNames( + LeftSide, RightSide) + if (!possibleStoreNames.contains(name)) { + val errorMsg = s"Store name $name not allowed for join operator. Allowed names are " + + s"$possibleStoreNames. " + + s"Please remove this option and re-run the query." + throw StateDataSourceErrors.invalidOptionValue(STORE_NAME, errorMsg) + } + case opName: String if StatefulOperatorsUtils.TRANSFORM_WITH_STATE_OP_NAMES + .contains(opName) => + runTWSChecks(opMetadata) + case _ => + // if we are trying to query state source with state variable name, then the operator + // should be transformWithState + val errorMsg = "Providing state variable names is only supported with the " + + s"transformWithState operator. Found operator=${opMetadata.operatorName}. " + + s"Please remove this option and re-run the query." + throw StateDataSourceErrors.invalidOptionValue(STATE_VAR_NAME, errorMsg) + } + case None => + if (sourceOptions.readRegisteredTimers) { + // Check that stateStoreMetadata exists + require(stateStoreMetadata.size == 1) + val opMetadata = stateStoreMetadata.head + opMetadata.operatorName match { + case opName: String if StatefulOperatorsUtils.TRANSFORM_WITH_STATE_OP_NAMES + .contains(opName) => + runTWSChecks(opMetadata) + case _ => + // if we are trying to query state source with state variable name, then the operator + // should be transformWithState + val errorMsg = "Providing readRegisteredTimers=true is only supported with the " + + s"transformWithState operator. Found operator=${opMetadata.operatorName}. " + + s"Please remove this option and re-run the query." 
+ throw StateDataSourceErrors.invalidOptionValue(READ_REGISTERED_TIMERS, errorMsg) + } + } else { + // if the operator is transformWithState, then a state variable argument is mandatory + if (stateStoreMetadata.size == 1 && + StatefulOperatorsUtils.TRANSFORM_WITH_STATE_OP_NAMES.contains( + stateStoreMetadata.head.operatorName)) { + throw StateDataSourceErrors.requiredOptionUnspecified("stateVarName") + } + } } - stateStoreMetadata } private def getStoreMetadataAndRunChecks(sourceOptions: StateSourceOptions): StateStoreReaderInfo = { - val storeMetadata = getStateStoreMetadata(sourceOptions) + val storeMetadata = StateDataSource.getStateStoreMetadata(sourceOptions, hadoopConf) runStateVarChecks(sourceOptions, storeMetadata) + var keyStateEncoderSpecOpt: Option[KeyStateEncoderSpec] = None var stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema] = None var transformWithStateVariableInfoOpt: Option[TransformWithStateVariableInfo] = None var stateSchemaProvider: Option[StateSchemaProvider] = None + var joinColFamilyOpt: Option[String] = None var timeMode: String = TimeMode.None.toString if (sourceOptions.joinSide == JoinSideValues.none) { @@ -220,34 +242,41 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging // Read the schema file path from operator metadata version v2 onwards // for the transformWithState operator - val oldSchemaFilePaths = if (storeMetadata.length > 0 && storeMetadata.head.version == 2 - && twsShortNameSeq.exists(storeMetadata.head.operatorName.contains)) { - val storeMetadataEntry = storeMetadata.head - val operatorProperties = TransformWithStateOperatorProperties.fromJson( - storeMetadataEntry.operatorPropertiesJson) - timeMode = operatorProperties.timeMode - - if (sourceOptions.readRegisteredTimers) { - stateVarName = TimerStateUtils.getTimerStateVarNames(timeMode)._1 + val oldSchemaFilePaths = if (storeMetadata.length > 0 && storeMetadata.head.version == 2) { + val opName = storeMetadata.head.operatorName + if (StatefulOperatorsUtils.TRANSFORM_WITH_STATE_OP_NAMES.exists(opName.contains)) { + val storeMetadataEntry = storeMetadata.head + val operatorProperties = TransformWithStateOperatorProperties.fromJson( + storeMetadataEntry.operatorPropertiesJson) + timeMode = operatorProperties.timeMode + + if (sourceOptions.readRegisteredTimers) { + stateVarName = TimerStateUtils.getTimerStateVarNames(timeMode)._1 + } + + val stateVarInfoList = operatorProperties.stateVariables + .filter(stateVar => stateVar.stateName == stateVarName) + require(stateVarInfoList.size == 1, s"Failed to find unique state variable info " + + s"for state variable $stateVarName in operator ${sourceOptions.operatorId}") + val stateVarInfo = stateVarInfoList.head + transformWithStateVariableInfoOpt = Some(stateVarInfo) + val schemaFilePaths = storeMetadataEntry.stateSchemaFilePaths + val stateSchemaMetadata = StateSchemaMetadata.createStateSchemaMetadata( + sourceOptions.stateCheckpointLocation.toString, + hadoopConf, + schemaFilePaths + ) + stateSchemaProvider = Some(new InMemoryStateSchemaProvider(stateSchemaMetadata)) + schemaFilePaths.map(new Path(_)) + } else { + if (opName == StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME) { + joinColFamilyOpt = Some(stateVarName) + } + StateDataSource.getOldSchemaFilePaths(sourceOptions, hadoopConf) } - - val stateVarInfoList = operatorProperties.stateVariables - .filter(stateVar => stateVar.stateName == stateVarName) - require(stateVarInfoList.size == 1, s"Failed to find unique state variable info " + - s"for 
state variable $stateVarName in operator ${sourceOptions.operatorId}") - val stateVarInfo = stateVarInfoList.head - transformWithStateVariableInfoOpt = Some(stateVarInfo) - val schemaFilePaths = storeMetadataEntry.stateSchemaFilePaths - val stateSchemaMetadata = StateSchemaMetadata.createStateSchemaMetadata( - sourceOptions.stateCheckpointLocation.toString, - hadoopConf, - schemaFilePaths - ) - stateSchemaProvider = Some(new InMemoryStateSchemaProvider(stateSchemaMetadata)) - schemaFilePaths.map(new Path(_)) } else { - None - }.toList + StateDataSource.getOldSchemaFilePaths(sourceOptions, hadoopConf) + } try { // Read the actual state schema from the provided path for v2 or from the dedicated path @@ -276,7 +305,8 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging keyStateEncoderSpecOpt, stateStoreColFamilySchemaOpt, transformWithStateVariableInfoOpt, - stateSchemaProvider + stateSchemaProvider, + joinColFamilyOpt ) } @@ -553,6 +583,27 @@ object StateSourceOptions extends DataSourceOptions { case None => throw StateDataSourceErrors.committedBatchUnavailable(checkpointLocation) } } + + // Modifies options due to external data. Returns modified options. + // If this is a join operator specifying a store name using state format v3, + // we need to modify the options. + private[state] def modifySourceOptions( + hadoopConf: Configuration, sourceOptions: StateSourceOptions): StateSourceOptions = { + // If a storeName is specified (e.g. right-keyToNumValues) and v3 is used, + // we are using join with virtual column families not diff stores. Therefore, + // options will be modified to set stateVarName to that storeName and storeName + // to default. + if (sourceOptions.storeName != StateStoreId.DEFAULT_STORE_NAME && + StreamStreamJoinStateHelper.usesVirtualColumnFamilies( + hadoopConf, sourceOptions.stateCheckpointLocation.toString, + sourceOptions.operatorId)) { + sourceOptions.copy( + stateVarName = Some(sourceOptions.storeName), + storeName = StateStoreId.DEFAULT_STORE_NAME) + } else { + sourceOptions + } + } } // Case class to store information around the key state encoder, col family schema and @@ -561,5 +612,28 @@ case class StateStoreReaderInfo( keyStateEncoderSpecOpt: Option[KeyStateEncoderSpec], stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema], transformWithStateVariableInfoOpt: Option[TransformWithStateVariableInfo], - stateSchemaProviderOpt: Option[StateSchemaProvider] + stateSchemaProviderOpt: Option[StateSchemaProvider], + joinColFamilyOpt: Option[String] // Only used for join op with state format v3 ) + +object StateDataSource { + private def getStateStoreMetadata( + stateSourceOptions: StateSourceOptions, + hadoopConf: Configuration): Array[StateMetadataTableEntry] = { + val allStateStoreMetadata = new StateMetadataPartitionReader( + stateSourceOptions.stateCheckpointLocation.getParent.toString, + new SerializableConfiguration(hadoopConf), stateSourceOptions.batchId).stateMetadata.toArray + val stateStoreMetadata = allStateStoreMetadata.filter { entry => + entry.operatorId == stateSourceOptions.operatorId && + entry.stateStoreName == stateSourceOptions.storeName + } + stateStoreMetadata + } + + def getOldSchemaFilePaths( + stateSourceOptions: StateSourceOptions, + hadoopConf: Configuration): List[Path] = { + val metadata = getStateStoreMetadata(stateSourceOptions, hadoopConf) + metadata.headOption.map(_.stateSchemaFilePaths.map(new Path(_))).getOrElse(List.empty) + } +} diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala index 4aa95ad42ec7..ce6dae8933eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -42,7 +42,8 @@ class StatePartitionReaderFactory( keyStateEncoderSpec: KeyStateEncoderSpec, stateVariableInfoOpt: Option[TransformWithStateVariableInfo], stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema], - stateSchemaProviderOpt: Option[StateSchemaProvider]) + stateSchemaProviderOpt: Option[StateSchemaProvider], + joinColFamilyOpt: Option[String]) extends PartitionReaderFactory { override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { @@ -50,11 +51,11 @@ class StatePartitionReaderFactory( if (stateStoreInputPartition.sourceOptions.readChangeFeed) { new StateStoreChangeDataPartitionReader(storeConf, hadoopConf, stateStoreInputPartition, schema, keyStateEncoderSpec, stateVariableInfoOpt, - stateStoreColFamilySchemaOpt, stateSchemaProviderOpt) + stateStoreColFamilySchemaOpt, stateSchemaProviderOpt, joinColFamilyOpt) } else { new StatePartitionReader(storeConf, hadoopConf, stateStoreInputPartition, schema, keyStateEncoderSpec, stateVariableInfoOpt, - stateStoreColFamilySchemaOpt, stateSchemaProviderOpt) + stateStoreColFamilySchemaOpt, stateSchemaProviderOpt, joinColFamilyOpt) } } } @@ -71,7 +72,8 @@ abstract class StatePartitionReaderBase( keyStateEncoderSpec: KeyStateEncoderSpec, stateVariableInfoOpt: Option[TransformWithStateVariableInfo], stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema], - stateSchemaProviderOpt: Option[StateSchemaProvider]) + stateSchemaProviderOpt: Option[StateSchemaProvider], + joinColFamilyOpt: Option[String]) extends PartitionReader[InternalRow] with Logging { // Used primarily as a placeholder for the value schema in the context of // state variables used within the transformWithState operator. 
@@ -98,11 +100,7 @@ abstract class StatePartitionReaderBase( partition.sourceOptions.operatorId, partition.partition, partition.sourceOptions.storeName) val stateStoreProviderId = StateStoreProviderId(stateStoreId, partition.queryId) - val useColFamilies = if (stateVariableInfoOpt.isDefined) { - true - } else { - false - } + val useColFamilies = stateVariableInfoOpt.isDefined || joinColFamilyOpt.isDefined val useMultipleValuesPerKey = SchemaUtil.checkVariableType(stateVariableInfoOpt, StateVariableType.ListState) @@ -164,10 +162,11 @@ class StatePartitionReader( keyStateEncoderSpec: KeyStateEncoderSpec, stateVariableInfoOpt: Option[TransformWithStateVariableInfo], stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema], - stateSchemaProviderOpt: Option[StateSchemaProvider]) + stateSchemaProviderOpt: Option[StateSchemaProvider], + joinColFamilyOpt: Option[String]) extends StatePartitionReaderBase(storeConf, hadoopConf, partition, schema, keyStateEncoderSpec, stateVariableInfoOpt, stateStoreColFamilySchemaOpt, - stateSchemaProviderOpt) { + stateSchemaProviderOpt, joinColFamilyOpt) { private lazy val store: ReadStateStore = { partition.sourceOptions.fromSnapshotOptions match { @@ -186,17 +185,18 @@ class StatePartitionReader( } override lazy val iter: Iterator[InternalRow] = { - val stateVarName = stateVariableInfoOpt - .map(_.stateName).getOrElse(StateStore.DEFAULT_COL_FAMILY_NAME) + val colFamilyName = stateStoreColFamilySchemaOpt + .map(_.colFamilyName).getOrElse( + joinColFamilyOpt.getOrElse(StateStore.DEFAULT_COL_FAMILY_NAME)) if (stateVariableInfoOpt.isDefined) { val stateVariableInfo = stateVariableInfoOpt.get val stateVarType = stateVariableInfo.stateVariableType - SchemaUtil.processStateEntries(stateVarType, stateVarName, store, + SchemaUtil.processStateEntries(stateVarType, colFamilyName, store, keySchema, partition.partition, partition.sourceOptions) } else { store - .iterator(stateVarName) + .iterator(colFamilyName) .map { pair => SchemaUtil.unifyStateRowPair((pair.key, pair.value), partition.partition) } @@ -221,10 +221,11 @@ class StateStoreChangeDataPartitionReader( keyStateEncoderSpec: KeyStateEncoderSpec, stateVariableInfoOpt: Option[TransformWithStateVariableInfo], stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema], - stateSchemaProviderOpt: Option[StateSchemaProvider]) + stateSchemaProviderOpt: Option[StateSchemaProvider], + joinColFamilyOpt: Option[String]) extends StatePartitionReaderBase(storeConf, hadoopConf, partition, schema, keyStateEncoderSpec, stateVariableInfoOpt, stateStoreColFamilySchemaOpt, - stateSchemaProviderOpt) { + stateSchemaProviderOpt, joinColFamilyOpt) { private lazy val changeDataReader: NextIterator[(RecordType.Value, UnsafeRow, UnsafeRow, Long)] = { @@ -235,6 +236,8 @@ class StateStoreChangeDataPartitionReader( val colFamilyNameOpt = if (stateVariableInfoOpt.isDefined) { Some(stateVariableInfoOpt.get.stateName) + } else if (joinColFamilyOpt.isDefined) { + Some(joinColFamilyOpt.get) } else { None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala index 3b8dad7a1809..9adabc7096bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala @@ -45,9 +45,11 @@ class StateScanBuilder( keyStateEncoderSpec: 
KeyStateEncoderSpec, stateVariableInfoOpt: Option[TransformWithStateVariableInfo], stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema], - stateSchemaProviderOpt: Option[StateSchemaProvider]) extends ScanBuilder { + stateSchemaProviderOpt: Option[StateSchemaProvider], + joinColFamilyOpt: Option[String]) extends ScanBuilder { override def build(): Scan = new StateScan(session, schema, sourceOptions, stateStoreConf, - keyStateEncoderSpec, stateVariableInfoOpt, stateStoreColFamilySchemaOpt, stateSchemaProviderOpt) + keyStateEncoderSpec, stateVariableInfoOpt, stateStoreColFamilySchemaOpt, stateSchemaProviderOpt, + joinColFamilyOpt) } /** An implementation of [[InputPartition]] for State Store data source. */ @@ -65,7 +67,8 @@ class StateScan( keyStateEncoderSpec: KeyStateEncoderSpec, stateVariableInfoOpt: Option[TransformWithStateVariableInfo], stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema], - stateSchemaProviderOpt: Option[StateSchemaProvider]) + stateSchemaProviderOpt: Option[StateSchemaProvider], + joinColFamilyOpt: Option[String]) extends Scan with Batch { // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it @@ -120,24 +123,28 @@ class StateScan( override def createReaderFactory(): PartitionReaderFactory = sourceOptions.joinSide match { case JoinSideValues.left => val userFacingSchema = schema + val oldSchemaFilePaths = StateDataSource.getOldSchemaFilePaths(sourceOptions, + hadoopConfBroadcast.value.value) val stateSchema = StreamStreamJoinStateHelper.readSchema(session, sourceOptions.stateCheckpointLocation.toString, sourceOptions.operatorId, LeftSide, - excludeAuxColumns = false) + oldSchemaFilePaths, excludeAuxColumns = false) new StreamStreamJoinStatePartitionReaderFactory(stateStoreConf, hadoopConfBroadcast.value, userFacingSchema, stateSchema) case JoinSideValues.right => val userFacingSchema = schema + val oldSchemaFilePaths = StateDataSource.getOldSchemaFilePaths(sourceOptions, + hadoopConfBroadcast.value.value) val stateSchema = StreamStreamJoinStateHelper.readSchema(session, sourceOptions.stateCheckpointLocation.toString, sourceOptions.operatorId, RightSide, - excludeAuxColumns = false) + oldSchemaFilePaths, excludeAuxColumns = false) new StreamStreamJoinStatePartitionReaderFactory(stateStoreConf, hadoopConfBroadcast.value, userFacingSchema, stateSchema) case JoinSideValues.none => new StatePartitionReaderFactory(stateStoreConf, hadoopConfBroadcast.value, schema, keyStateEncoderSpec, stateVariableInfoOpt, stateStoreColFamilySchemaOpt, - stateSchemaProviderOpt) + stateSchemaProviderOpt, joinColFamilyOpt) } override def toBatch: Batch = this diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala index 71b18be7fdf5..96614f0613c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala @@ -44,7 +44,8 @@ class StateTable( keyStateEncoderSpec: KeyStateEncoderSpec, stateVariableInfoOpt: Option[TransformWithStateVariableInfo], stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema], - stateSchemaProviderOpt: Option[StateSchemaProvider]) + stateSchemaProviderOpt: Option[StateSchemaProvider], + joinColFamilyOpt: Option[String]) extends Table with SupportsRead with SupportsMetadataColumns { import StateTable._ @@ -85,7 +86,8 @@ class 
StateTable( override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new StateScanBuilder(session, schema, sourceOptions, stateConf, keyStateEncoderSpec, - stateVariableInfoOpt, stateStoreColFamilySchemaOpt, stateSchemaProviderOpt) + stateVariableInfoOpt, stateStoreColFamilySchemaOpt, stateSchemaProviderOpt, + joinColFamilyOpt) override def properties(): util.Map[String, String] = Map.empty[String, String].asJava diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStateHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStateHelper.scala index 1a04d24f0048..3abd1924f543 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStateHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStateHelper.scala @@ -18,8 +18,12 @@ package org.apache.spark.sql.execution.datasources.v2.state import java.util.UUID +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.JoinSide +import org.apache.spark.sql.execution.streaming.CheckpointFileManager +import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{JoinSide, LeftSide} import org.apache.spark.sql.execution.streaming.state.{StateSchemaCompatibilityChecker, StateStore, StateStoreId, StateStoreProviderId, SymmetricHashJoinStateManager} import org.apache.spark.sql.types.{BooleanType, StructType} @@ -35,52 +39,92 @@ object StreamStreamJoinStateHelper { stateCheckpointLocation: String, operatorId: Int, side: JoinSide, + oldSchemaFilePaths: List[Path], excludeAuxColumns: Boolean = true): StructType = { val (keySchema, valueSchema) = readKeyValueSchema(session, stateCheckpointLocation, - operatorId, side, excludeAuxColumns) + operatorId, side, oldSchemaFilePaths, excludeAuxColumns) new StructType() .add("key", keySchema) .add("value", valueSchema) } + // Returns whether the checkpoint uses stateFormatVersion 3 which uses VCF for the join. + def usesVirtualColumnFamilies( + hadoopConf: Configuration, + stateCheckpointLocation: String, + operatorId: Int): Boolean = { + // If the schema exists for operatorId/partitionId/left-keyToNumValues, it is not + // stateFormatVersion 3. 
+ val partitionId = StateStore.PARTITION_ID_TO_CHECK_SCHEMA + val storeId = new StateStoreId(stateCheckpointLocation, operatorId, + partitionId, SymmetricHashJoinStateManager.allStateStoreNames(LeftSide).toList.head) + val schemaFilePath = StateSchemaCompatibilityChecker.schemaFile( + storeId.storeCheckpointLocation()) + val fm = CheckpointFileManager.create(schemaFilePath, hadoopConf) + !fm.exists(schemaFilePath) + } + def readKeyValueSchema( session: SparkSession, stateCheckpointLocation: String, operatorId: Int, side: JoinSide, + oldSchemaFilePaths: List[Path], excludeAuxColumns: Boolean = true): (StructType, StructType) = { + val newHadoopConf = session.sessionState.newHadoopConf() + val partitionId = StateStore.PARTITION_ID_TO_CHECK_SCHEMA // KeyToNumValuesType, KeyWithIndexToValueType val storeNames = SymmetricHashJoinStateManager.allStateStoreNames(side).toList - val partitionId = StateStore.PARTITION_ID_TO_CHECK_SCHEMA - val storeIdForKeyToNumValues = new StateStoreId(stateCheckpointLocation, operatorId, - partitionId, storeNames(0)) - val providerIdForKeyToNumValues = new StateStoreProviderId(storeIdForKeyToNumValues, - UUID.randomUUID()) + val (keySchema, valueSchema) = + if (!usesVirtualColumnFamilies( + newHadoopConf, stateCheckpointLocation, operatorId)) { + val storeIdForKeyToNumValues = new StateStoreId(stateCheckpointLocation, operatorId, + partitionId, storeNames(0)) + val providerIdForKeyToNumValues = new StateStoreProviderId(storeIdForKeyToNumValues, + UUID.randomUUID()) - val storeIdForKeyWithIndexToValue = new StateStoreId(stateCheckpointLocation, - operatorId, partitionId, storeNames(1)) - val providerIdForKeyWithIndexToValue = new StateStoreProviderId(storeIdForKeyWithIndexToValue, - UUID.randomUUID()) + val storeIdForKeyWithIndexToValue = new StateStoreId(stateCheckpointLocation, + operatorId, partitionId, storeNames(1)) + val providerIdForKeyWithIndexToValue = new StateStoreProviderId( + storeIdForKeyWithIndexToValue, UUID.randomUUID()) - val newHadoopConf = session.sessionState.newHadoopConf() + // read the key schema from the keyToNumValues store for the join keys + val manager = new StateSchemaCompatibilityChecker( + providerIdForKeyToNumValues, newHadoopConf, oldSchemaFilePaths) + val kSchema = manager.readSchemaFile().head.keySchema + + // read the value schema from the keyWithIndexToValue store for the values + val manager2 = new StateSchemaCompatibilityChecker(providerIdForKeyWithIndexToValue, + newHadoopConf, oldSchemaFilePaths) + val vSchema = manager2.readSchemaFile().head.valueSchema + + (kSchema, vSchema) + } else { + val storeId = new StateStoreId(stateCheckpointLocation, operatorId, + partitionId, StateStoreId.DEFAULT_STORE_NAME) + val providerId = new StateStoreProviderId(storeId, UUID.randomUUID()) + + val manager = new StateSchemaCompatibilityChecker( + providerId, newHadoopConf, oldSchemaFilePaths) + val kSchema = manager.readSchemaFile().find { schema => + schema.colFamilyName == storeNames(0) + }.map(_.keySchema).get - // read the key schema from the keyToNumValues store for the join keys - val manager = new StateSchemaCompatibilityChecker(providerIdForKeyToNumValues, newHadoopConf) - val keySchema = manager.readSchemaFile().head.keySchema + val vSchema = manager.readSchemaFile().find { schema => + schema.colFamilyName == storeNames(1) + }.map(_.valueSchema).get - // read the value schema from the keyWithIndexToValue store for the values - val manager2 = new StateSchemaCompatibilityChecker(providerIdForKeyWithIndexToValue, - newHadoopConf) - val 
valueSchema = manager2.readSchemaFile().head.valueSchema + (kSchema, vSchema) + } val maybeMatchedColumn = valueSchema.last if (excludeAuxColumns - && maybeMatchedColumn.name == "matched" - && maybeMatchedColumn.dataType == BooleanType) { + && maybeMatchedColumn.name == "matched" + && maybeMatchedColumn.dataType == BooleanType) { // remove internal column `matched` for format version 2 (keySchema, StructType(valueSchema.dropRight(1))) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala index e1d61de77380..82415b9f30c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala @@ -80,8 +80,18 @@ class StreamStreamJoinStatePartitionReader( private val (inputAttributes, formatVersion) = { val maybeMatchedColumn = valueSchema.last val (fields, version) = { + // If there is a matched column, version is either 2 or 3. We need to drop the matched + // column from the value schema to get the actual fields. if (maybeMatchedColumn.name == "matched" && maybeMatchedColumn.dataType == BooleanType) { - (valueSchema.dropRight(1), 2) + // If checkpoint is using one store and virtual column families, version is 3 + if (StreamStreamJoinStateHelper.usesVirtualColumnFamilies( + hadoopConf.value, + partition.sourceOptions.stateCheckpointLocation.toString, + partition.sourceOptions.operatorId)) { + (valueSchema.dropRight(1), 3) + } else { + (valueSchema.dropRight(1), 2) + } } else { (valueSchema, 1) } @@ -137,7 +147,7 @@ class StreamStreamJoinStatePartitionReader( inputAttributes) joinStateManager.iterator.map { pair => - if (formatVersion == 2) { + if (formatVersion >= 2) { val row = valueWithMatchedRowGenerator(pair.value) row.setBoolean(indexOrdinalInValueWithMatchedRow, pair.matched) unifyStateRowPair(pair.key, row) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala index b65d46fb1632..4cb924313248 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.{CoGroupedIterator, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.python.ArrowPythonRunner import org.apache.spark.sql.execution.python.PandasGroupUtils.{executePython, groupAndProject, resolveArgOffsets} -import org.apache.spark.sql.execution.streaming.{DriverStatefulProcessorHandleImpl, StatefulOperatorStateInfo, StatefulProcessorHandleImpl, TransformWithStateExecBase, TransformWithStateVariableInfo} +import org.apache.spark.sql.execution.streaming.{DriverStatefulProcessorHandleImpl, StatefulOperatorStateInfo, StatefulOperatorsUtils, StatefulProcessorHandleImpl, TransformWithStateExecBase, TransformWithStateVariableInfo} import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.StateStoreAwareZipPartitionsHelper import 
org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, RocksDBStateStoreProvider, StateSchemaValidationResult, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreOps, StateStoreProvider, StateStoreProviderId} import org.apache.spark.sql.internal.SQLConf @@ -95,9 +95,9 @@ case class TransformWithStateInPySparkExec( override def shortName: String = if ( userFacingDataType == TransformWithStateInPySpark.UserFacingDataType.PANDAS ) { - "transformWithStateInPandasExec" + StatefulOperatorsUtils.TRANSFORM_WITH_STATE_IN_PANDAS_EXEC_OP_NAME } else { - "transformWithStateInPySparkExec" + StatefulOperatorsUtils.TRANSFORM_WITH_STATE_IN_PYSPARK_EXEC_OP_NAME } private val pythonUDF = functionExpr.asInstanceOf[PythonUDF] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala index 7d71db8d8e4b..839d610550ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala @@ -236,7 +236,7 @@ case class StreamingSymmetricHashJoinExec( case _ => throwBadJoinTypeException() } - override def shortName: String = "symmetricHashJoin" + override def shortName: String = StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME override val stateStoreNames: Seq[String] = _stateStoreNames diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala index 6ec197d7cc7b..619671d99e57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala @@ -737,9 +737,9 @@ abstract class SymmetricHashJoinStateManager( if (useVirtualColumnFamilies) { stateStore.createColFamilyIfAbsent( colFamilyName, - keySchema, + keyWithIndexSchema, valueRowConverter.valueAttributes.toStructType, - NoPrefixKeyStateEncoderSpec(keySchema) + NoPrefixKeyStateEncoderSpec(keyWithIndexSchema) ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala index d92e5dbae1aa..027b911262ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala @@ -1546,3 +1546,16 @@ trait SchemaValidationUtils extends Logging { schemaEvolutionEnabled = usingAvro && schemaEvolutionEnabledForOperator)) } } + +object StatefulOperatorsUtils { + val TRANSFORM_WITH_STATE_EXEC_OP_NAME = "transformWithStateExec" + val TRANSFORM_WITH_STATE_IN_PANDAS_EXEC_OP_NAME = "transformWithStateInPandasExec" + val TRANSFORM_WITH_STATE_IN_PYSPARK_EXEC_OP_NAME = "transformWithStateInPySparkExec" + // Seq of operator names who uses state schema v3 and TWS related options. 
+ val TRANSFORM_WITH_STATE_OP_NAMES: Seq[String] = Seq( + TRANSFORM_WITH_STATE_EXEC_OP_NAME, + TRANSFORM_WITH_STATE_IN_PANDAS_EXEC_OP_NAME, + TRANSFORM_WITH_STATE_IN_PYSPARK_EXEC_OP_NAME + ) + val SYMMETRIC_HASH_JOIN_EXEC_OP_NAME = "symmetricHashJoin" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala index 80fdaa1e71e2..db3b7841ac8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala @@ -84,7 +84,7 @@ case class TransformWithStateExec( initialState) with ObjectProducerExec { - override def shortName: String = "transformWithStateExec" + override def shortName: String = StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME // We need to just initialize key and value deserializer once per partition. // The deserializers need to be lazily created on the executor since they diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 6b3bec207703..57a8265524b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -644,6 +644,18 @@ class RocksDB( workingDir, rocksDBFileMapping) loadedVersion = snapshotVersion lastSnapshotVersion = snapshotVersion + + setInitialCFInfo() + metadata.columnFamilyMapping.foreach { mapping => + mapping.foreach { case (colFamilyName, cfId) => + addToColFamilyMaps(colFamilyName, cfId, isInternalColFamily(colFamilyName, metadata)) + } + } + + metadata.maxColumnFamilyId.foreach { maxId => + maxColumnFamilyId.set(maxId) + } + openDB() val (numKeys, numInternalKeys) = if (!conf.trackTotalNumberOfRows) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala index 17a36e5210b9..08cfd8fb197a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.util.UnsafeRowUtils import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, StatefulOperatorStateInfo} import org.apache.spark.sql.execution.streaming.state.SchemaHelper.{SchemaReader, SchemaWriter} import org.apache.spark.sql.execution.streaming.state.StateSchemaCompatibilityChecker.SCHEMA_FORMAT_V3 -import org.apache.spark.sql.internal.SessionState +import org.apache.spark.sql.internal.{SessionState, SQLConf} import org.apache.spark.sql.types._ // Result returned after validating the schema of the state store for schema changes @@ -88,7 +88,7 @@ class StateSchemaCompatibilityChecker( // per query. 
This variable is the latest one private val schemaFileLocation = if (oldSchemaFilePaths.isEmpty) { val storeCpLocation = providerId.storeId.storeCheckpointLocation() - schemaFile(storeCpLocation) + StateSchemaCompatibilityChecker.schemaFile(storeCpLocation) } else { oldSchemaFilePaths.last } @@ -97,7 +97,7 @@ class StateSchemaCompatibilityChecker( fm.mkdirs(schemaFileLocation.getParent) - private val conf = SparkSession.getActiveSession.get.sessionState.conf + private val conf = SparkSession.getActiveSession.map(_.sessionState.conf).getOrElse(new SQLConf()) // Read most recent schema file def readSchemaFile(): List[StateStoreColFamilySchema] = { @@ -302,9 +302,6 @@ class StateSchemaCompatibilityChecker( newSchemaFileWritten } } - - private def schemaFile(storeCpLocation: Path): Path = - new Path(new Path(storeCpLocation, "_metadata"), "schema") } object StateSchemaCompatibilityChecker extends Logging { @@ -432,4 +429,7 @@ object StateSchemaCompatibilityChecker extends Logging { StateSchemaValidationResult(evolvedSchema, schemaFileLocation) } + + def schemaFile(storeCpLocation: Path): Path = + new Path(new Path(storeCpLocation, "_metadata"), "schema") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala index 56a6a1e641f4..dbcf0afb1273 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala @@ -526,6 +526,9 @@ class RocksDBStateDataSourceReadSuite extends StateDataSourceReadSuite { class RocksDBWithChangelogCheckpointStateDataSourceReaderSuite extends StateDataSourceReadSuite { + + import testImplicits._ + override protected def newStateStoreProvider(): RocksDBStateStoreProvider = new RocksDBStateStoreProvider @@ -568,6 +571,49 @@ StateDataSourceReadSuite { testSnapshotOnJoinState("rocksdb", 1) testSnapshotOnJoinState("rocksdb", 2) } + + /** + * Note that we cannot use the golden files approach for transformWithState. The new schema + * format keeps track of the schema file path as an absolute path which cannot be used with + * the getResource model used in other similar tests on runbot. 
+ */ + test("snapshotStartBatchId on join state v3") { + withTempDir { tmpDir => + withSQLConf( + SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION.key -> "3", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100" + ) { + val inputData = MemoryStream[(Int, Long)] + val query = getStreamStreamJoinQuery(inputData) + testStream(query)( + StartStream(checkpointLocation = tmpDir.getCanonicalPath), + AddData(inputData, (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L)), + ProcessAllAvailable(), + Execute { _ => Thread.sleep(2000) }, + AddData(inputData, (6, 6L), (7, 7L), (8, 8L), (9, 9L), (10, 10L)), + ProcessAllAvailable(), + Execute { _ => Thread.sleep(2000) }, + AddData(inputData, (11, 11L), (12, 12L), (13, 13L), (14, 14L), (15, 15L)), + ProcessAllAvailable(), + Execute { _ => Thread.sleep(5000) }, + StopStream + ) + + val stateSnapshotDf = spark.read.format("statestore") + .option("snapshotPartitionId", 2) + .option("snapshotStartBatchId", 0) + .option("joinSide", "left") + .load(tmpDir.getCanonicalPath) + + val stateDf = spark.read.format("statestore") + .option("joinSide", "left") + .load(tmpDir.getCanonicalPath) + .filter(col("partition_id") === 2) + + checkAnswer(stateSnapshotDf, stateDf) + } + } + } } abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Assertions { @@ -869,6 +915,10 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass testStreamStreamJoin(2) } + test("stream-stream join, state ver 3") { + testStreamStreamJoin(3) + } + private def testStreamStreamJoin(stateVersion: Int): Unit = { def assertInternalColumnIsNotExposed(df: DataFrame): Unit = { val valueSchema = SchemaUtil.getSchemaAsDataType(df.schema, "value") @@ -879,6 +929,12 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass } } + // We should only test state version 3 with RocksDBStateStoreProvider + if (stateVersion == 3 + && SQLConf.get.stateStoreProviderClass != classOf[RocksDBStateStoreProvider].getName) { + return + } + withSQLConf(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION.key -> stateVersion.toString) { withTempDir { tempDir => runStreamStreamJoinQuery(tempDir.getAbsolutePath) @@ -939,7 +995,7 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass val stateReadDfForRightKeyWithIndexToValue = stateReaderForRightKeyWithIndexToValue.load() - if (stateVersion == 2) { + if (stateVersion >= 2) { val resultDf4 = stateReadDfForRightKeyWithIndexToValue .selectExpr("key.field0 AS key_0", "key.index AS key_index", "value.rightId AS rightId", "CAST(value.rightTime AS integer) AS rightTime", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 43e064d86117..5c75f82ebef1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -696,9 +696,7 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite { }) } - // This does not need to be run with virtual column family joins as it restores the state store - // provider to HDFS and join version to 1, effectively disabling the virtual column family join. 
- testWithoutVirtualColumnFamilyJoins( + test( "SPARK-26187 restore the stream-stream inner join query from Spark 2.4") { val inputStream = MemoryStream[(Int, Long)] val df = inputStream.toDS() @@ -1524,9 +1522,7 @@ class StreamingOuterJoinSuite extends StreamingJoinSuite { ) } - // This does not need to be run with virtual column family joins as it restores the state store - // provider to HDFS and join version to 1, effectively disabling the virtual column family join. - testWithoutVirtualColumnFamilyJoins( + test( "SPARK-26187 restore the stream-stream outer join query from Spark 2.4") { val inputStream = MemoryStream[(Int, Long)] val df = inputStream.toDS() --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org
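For reference, a minimal usage sketch of reading stream-stream join state written with join state format version 3 through this data source, mirroring the new "snapshotStartBatchId on join state v3" test in StateDataSourceReadSuite. The checkpoint path is a placeholder and an active SparkSession named `spark` is assumed:

```scala
// Read the left side of the stream-stream join state from a checkpoint produced with
// SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION set to 3 (a single store backed by
// virtual column families). The path below is a placeholder.
val joinStateDf = spark.read
  .format("statestore")
  .option("joinSide", "left")   // or "right"
  .load("/path/to/checkpoint")

// Reconstruct the same state starting from an earlier snapshot, as the new test does.
val snapshotDf = spark.read
  .format("statestore")
  .option("joinSide", "left")
  .option("snapshotPartitionId", 2)
  .option("snapshotStartBatchId", 0)
  .load("/path/to/checkpoint")
```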