This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new e6dbc9e537ae [SPARK-51249][SS] Fixing the NoPrefixKeyStateEncoder and Avro encoding to use the correct number of version bytes
e6dbc9e537ae is described below

commit e6dbc9e537aecdefcdd20ecaa4b98b8a257860eb
Author: Eric Marnadi <eric.marn...@databricks.com>
AuthorDate: Fri Feb 21 09:53:10 2025 +0900

    [SPARK-51249][SS] Fixing the NoPrefixKeyStateEncoder and Avro encoding to use the correct number of version bytes

    ### What changes were proposed in this pull request?

    There are currently two bugs:
    - The NoPrefixKeyStateEncoder adds an extra version byte to each row when UnsafeRow encoding is used: https://github.com/apache/spark/pull/47107
    - Rows written with Avro encoding do not include a version byte: https://github.com/apache/spark/pull/48401

    **Neither of these bugs has been released: they are only triggered with multiple column families, which only transformWithState uses, and transformWithState first ships in Spark 4.0.0.**

    This change fixes both bugs.

    ### Why are the changes needed?

    These changes are needed to conform to the expected state row encoding format.

    ### Does this PR introduce _any_ user-facing change?

    No

    ### How was this patch tested?

    Unit tests

    ### Was this patch authored or co-authored using generative AI tooling?

    No

    Closes #49996 from ericm-db/SPARK-51249.

    Lead-authored-by: Eric Marnadi <eric.marn...@databricks.com>
    Co-authored-by: Eric Marnadi <132308037+ericm...@users.noreply.github.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
    (cherry picked from commit 42ab97a3e6e77657ecc5cf6f1ff47805eb08422a)
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
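For readers skimming the diff below, here is a minimal, self-contained sketch of the one-byte framing this patch enforces for Avro-encoded rows. It is an illustration, not the Spark API: the object name and `main` are hypothetical, plain array operations stand in for the `Platform.putByte`/`Platform.copyMemory` calls the real encoder uses, and the version value 0 is taken from the assertions in the updated test further down.

```scala
// Hedged sketch of the version-byte framing (assumed names; constants and
// behavior mirror the prependVersionByte/removeVersionByte helpers in the diff).
object VersionByteFraming {
  val STATE_ENCODING_VERSION: Byte = 0      // first byte asserted in RocksDBStateStoreSuite
  val STATE_ENCODING_NUM_VERSION_BYTES = 1

  // Counterpart of AvroStateEncoder.prependVersionByte below.
  def prependVersionByte(bytesToEncode: Array[Byte]): Array[Byte] =
    STATE_ENCODING_VERSION +: bytesToEncode

  // Counterpart of AvroStateEncoder.removeVersionByte below.
  def removeVersionByte(bytes: Array[Byte]): Array[Byte] =
    bytes.drop(STATE_ENCODING_NUM_VERSION_BYTES)

  def main(args: Array[String]): Unit = {
    val avroPayload = Array[Byte](2, 97, 2, 0)   // stand-in for an Avro-encoded key
    val framed = prependVersionByte(avroPayload)
    assert(framed.head == STATE_ENCODING_VERSION)
    assert(removeVersionByte(framed).sameElements(avroPayload))
    println(framed.mkString("[", ", ", "]"))     // [0, 2, 97, 2, 0]
  }
}
```

The point of the round-trip is that every encode path in AvroStateEncoder now ends with `prependVersionByte` and every decode path begins with `removeVersionByte`, so the framing is applied exactly once per row.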
---
 .../streaming/state/RocksDBStateEncoder.scala      | 96 ++++++++++++----------
 .../streaming/state/RocksDBStateStoreSuite.scala   | 14 ++++
 .../execution/streaming/state/RocksDBSuite.scala   |  9 +-
 3 files changed, 71 insertions(+), 48 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
index c7b324ec32e6..cf5f8ba5f2eb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
@@ -895,6 +895,40 @@ class AvroStateEncoder(
     out.toByteArray
   }
 
+  /**
+   * Prepends a version byte to the beginning of a byte array.
+   * This is used to maintain backward compatibility and version control of
+   * the state encoding format.
+   *
+   * @param bytesToEncode The original byte array to prepend the version byte to
+   * @return A new byte array with the version byte prepended at the beginning
+   */
+  private[sql] def prependVersionByte(bytesToEncode: Array[Byte]): Array[Byte] = {
+    val encodedBytes = new Array[Byte](bytesToEncode.length + STATE_ENCODING_NUM_VERSION_BYTES)
+    Platform.putByte(encodedBytes, Platform.BYTE_ARRAY_OFFSET, STATE_ENCODING_VERSION)
+    Platform.copyMemory(
+      bytesToEncode, Platform.BYTE_ARRAY_OFFSET,
+      encodedBytes, Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
+      bytesToEncode.length)
+    encodedBytes
+  }
+
+  /**
+   * Removes the version byte from the beginning of a byte array.
+   * This is used when decoding state data to get back to the original encoded format.
+   *
+   * @param bytes The byte array containing the version byte at the start
+   * @return A new byte array with the version byte removed
+   */
+  private[sql] def removeVersionByte(bytes: Array[Byte]): Array[Byte] = {
+    val resultBytes = new Array[Byte](bytes.length - STATE_ENCODING_NUM_VERSION_BYTES)
+    Platform.copyMemory(
+      bytes, STATE_ENCODING_NUM_VERSION_BYTES + Platform.BYTE_ARRAY_OFFSET,
+      resultBytes, Platform.BYTE_ARRAY_OFFSET, resultBytes.length
+    )
+    resultBytes
+  }
+
   /**
    * This method takes a byte array written using Avro encoding, and
    * deserializes to an UnsafeRow using the Avro deserializer
@@ -956,7 +990,7 @@ class AvroStateEncoder(
   private val out = new ByteArrayOutputStream
 
   override def encodeKey(row: UnsafeRow): Array[Byte] = {
-    keyStateEncoderSpec match {
+    val keyBytes = keyStateEncoderSpec match {
       case NoPrefixKeyStateEncoderSpec(_) =>
         val avroRow =
           encodeUnsafeRowToAvro(row, avroEncoder.keySerializer, keyAvroType, out)
@@ -967,6 +1001,7 @@ class AvroStateEncoder(
         encodeUnsafeRowToAvro(row, avroEncoder.keySerializer, prefixKeyAvroType, out)
       case _ => throw unsupportedOperationForKeyStateEncoder("encodeKey")
     }
+    prependVersionByte(keyBytes)
   }
 
   override def encodeRemainingKey(row: UnsafeRow): Array[Byte] = {
@@ -978,8 +1013,8 @@ class AvroStateEncoder(
       case _ => throw unsupportedOperationForKeyStateEncoder("encodeRemainingKey")
     }
     // prepend stateSchemaId to the remaining key portion
-    encodeWithStateSchemaId(
-      StateSchemaIdRow(currentKeySchemaId, avroRow))
+    prependVersionByte(encodeWithStateSchemaId(
+      StateSchemaIdRow(currentKeySchemaId, avroRow)))
   }
 
   /**
@@ -1118,16 +1153,18 @@ class AvroStateEncoder(
     val encoder = EncoderFactory.get().binaryEncoder(out, null)
     writer.write(record, encoder)
     encoder.flush()
-    out.toByteArray
+    prependVersionByte(out.toByteArray)
   }
 
   override def encodeValue(row: UnsafeRow): Array[Byte] = {
     val avroRow = encodeUnsafeRowToAvro(row, avroEncoder.valueSerializer, valueAvroType, out)
     // prepend stateSchemaId to the Avro-encoded value portion
-    encodeWithStateSchemaId(StateSchemaIdRow(currentValSchemaId, avroRow))
+    prependVersionByte(
+      encodeWithStateSchemaId(StateSchemaIdRow(currentValSchemaId, avroRow)))
   }
 
-  override def decodeKey(bytes: Array[Byte]): UnsafeRow = {
+  override def decodeKey(rowBytes: Array[Byte]): UnsafeRow = {
+    val bytes = removeVersionByte(rowBytes)
     keyStateEncoderSpec match {
       case NoPrefixKeyStateEncoderSpec(_) =>
         val schemaIdRow = decodeStateSchemaIdRow(bytes)
@@ -1141,7 +1178,8 @@ class AvroStateEncoder(
   }
 
-  override def decodeRemainingKey(bytes: Array[Byte]): UnsafeRow = {
+  override def decodeRemainingKey(rowBytes: Array[Byte]): UnsafeRow = {
+    val bytes = removeVersionByte(rowBytes)
     val schemaIdRow = decodeStateSchemaIdRow(bytes)
     keyStateEncoderSpec match {
       case PrefixKeyScanStateEncoderSpec(_, _) =>
@@ -1174,7 +1212,8 @@ class AvroStateEncoder(
    * @throws UnsupportedOperationException if a field's data type is not supported for range
    *                                       scan decoding
    */
-  override def decodePrefixKeyForRangeScan(bytes: Array[Byte]): UnsafeRow = {
+  override def decodePrefixKeyForRangeScan(rowBytes: Array[Byte]): UnsafeRow = {
+    val bytes = removeVersionByte(rowBytes)
     val reader = new GenericDatumReader[GenericRecord](rangeScanAvroType)
     val decoder = DecoderFactory.get().binaryDecoder(bytes, 0, bytes.length, null)
     val record = reader.read(null, decoder)
@@ -1257,7 +1296,8 @@ class AvroStateEncoder(
     rowWriter.getRow()
   }
 
-  override def decodeValue(bytes: Array[Byte]): UnsafeRow = {
+  override def decodeValue(rowBytes: Array[Byte]): UnsafeRow = {
+    val bytes = removeVersionByte(rowBytes)
     val schemaIdRow = decodeStateSchemaIdRow(bytes)
     val writerSchema = getStateSchemaProvider.getSchemaMetadataValue(
       StateSchemaMetadataKey(
@@ -1648,45 +1688,11 @@ class NoPrefixKeyStateEncoder(
   extends RocksDBKeyStateEncoder with Logging {
 
   override def encodeKey(row: UnsafeRow): Array[Byte] = {
-    if (!useColumnFamilies) {
-      dataEncoder.encodeKey(row)
-    } else {
-      // First encode the row with the data encoder
-      val rowBytes = dataEncoder.encodeKey(row)
-
-      // Create data array with version byte
-      val dataWithVersion = new Array[Byte](STATE_ENCODING_NUM_VERSION_BYTES + rowBytes.length)
-      Platform.putByte(dataWithVersion, Platform.BYTE_ARRAY_OFFSET, STATE_ENCODING_VERSION)
-      Platform.copyMemory(
-        rowBytes, Platform.BYTE_ARRAY_OFFSET,
-        dataWithVersion, Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
-        rowBytes.length
-      )
-
-      dataWithVersion
-    }
+    dataEncoder.encodeKey(row)
   }
 
   override def decodeKey(keyBytes: Array[Byte]): UnsafeRow = {
-    if (!useColumnFamilies) {
-      dataEncoder.decodeKey(keyBytes)
-    } else if (keyBytes == null) {
-      null
-    } else {
-      val dataWithVersion = keyBytes
-
-      // Skip version byte to get to actual data
-      val dataLength = dataWithVersion.length - STATE_ENCODING_NUM_VERSION_BYTES
-
-      // Extract data bytes and decode using data encoder
-      val dataBytes = new Array[Byte](dataLength)
-      Platform.copyMemory(
-        dataWithVersion, Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
-        dataBytes, Platform.BYTE_ARRAY_OFFSET,
-        dataLength
-      )
-      dataEncoder.decodeKey(dataBytes)
-    }
+    dataEncoder.decodeKey(keyBytes)
   }
 
   override def supportPrefixKeyScan: Boolean = false
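A hedged aside on the NoPrefixKeyStateEncoder half of the fix, before the test diffs: the data encoder already frames each key, so the wrapper's extra copy above produced a second version byte whenever column families were enabled. The toy model below (hypothetical names; only the one-byte framing and the version value 0 come from this patch) shows why dropping the wrapper-level framing leaves exactly one version byte per key.

```scala
// Toy model of the double-framing bug this patch removes (not Spark code).
object DoubleFramingModel {
  val STATE_ENCODING_VERSION: Byte = 0

  def frame(payload: Array[Byte]): Array[Byte] = STATE_ENCODING_VERSION +: payload

  def main(args: Array[String]): Unit = {
    val row = Array[Byte](42, 43)        // stand-in for an encoded key row

    // Before the fix: the data encoder framed the key, and
    // NoPrefixKeyStateEncoder.encodeKey framed it again under column families.
    val buggy = frame(frame(row))        // [0, 0, 42, 43] -- two version bytes
    // After the fix: framing happens exactly once, inside the data encoder.
    val fixed = frame(row)               // [0, 42, 43]

    assert(buggy.length == row.length + 2)
    assert(fixed.length == row.length + 1)
  }
}
```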
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
index 4d939db8796b..5aea0077e2aa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
@@ -77,6 +77,20 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
       // Verify the version encoded in first byte of the key and value byte arrays
       assert(Platform.getByte(kv.key, Platform.BYTE_ARRAY_OFFSET) === STATE_ENCODING_VERSION)
       assert(Platform.getByte(kv.value, Platform.BYTE_ARRAY_OFFSET) === STATE_ENCODING_VERSION)
+
+      // The test verifies that the actual key-value pair (kv) matches these expected byte patterns
+      // exactly using sameElements, which ensures the serialization format remains consistent and
+      // backward compatible. This is particularly important for state storage where the format
+      // needs to be stable across Spark versions.
+      val (expectedKey, expectedValue) = if (conf.stateStoreEncodingFormat == "avro") {
+        (Array(0, 0, 0, 2, 2, 97, 2, 0), Array(0, 0, 0, 2, 2))
+      } else {
+        (Array(0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 24, 0, 0,
+          0, 0, 0, 0, 0, 0, 0, 0, 0, 97, 0, 0, 0, 0, 0, 0, 0),
+          Array(0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0))
+      }
+      assert(kv.key.sameElements(expectedKey))
+      assert(kv.value.sameElements(expectedValue))
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
index 7d4614d59973..50240c0605e8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
@@ -433,7 +433,8 @@ class RocksDBStateEncoderSuite extends SparkFunSuite {
     val encodedRemainingKey = encoder.encodeRemainingKey(remainingKeyRow)
 
     // Verify schema ID in remaining key bytes
-    val decodedSchemaIdRow = encoder.decodeStateSchemaIdRow(encodedRemainingKey)
+    val decodedSchemaIdRow = encoder.decodeStateSchemaIdRow(
+      encoder.removeVersionByte(encodedRemainingKey))
     assert(decodedSchemaIdRow.schemaId === 18,
       "Schema ID not preserved in prefix scan remaining key encoding")
   }
@@ -462,7 +463,8 @@ class RocksDBStateEncoderSuite extends SparkFunSuite {
     val encodedRemainingKey = encoder.encodeRemainingKey(remainingKeyRow)
 
     // Verify schema ID in remaining key bytes
-    val decodedSchemaIdRow = encoder.decodeStateSchemaIdRow(encodedRemainingKey)
+    val decodedSchemaIdRow = encoder.decodeStateSchemaIdRow(
+      encoder.removeVersionByte(encodedRemainingKey))
     assert(decodedSchemaIdRow.schemaId === 24,
       "Schema ID not preserved in range scan remaining key encoding")
 
@@ -565,7 +567,8 @@ class RocksDBStateEncoderSuite extends SparkFunSuite {
     val encodedValue = valueEncoder.encodeValue(value)
 
     // Verify schema ID was included and preserved
-    val decodedSchemaIdRow = avroEncoder.decodeStateSchemaIdRow(encodedValue)
+    val decodedSchemaIdRow = avroEncoder.decodeStateSchemaIdRow(
+      avroEncoder.removeVersionByte(encodedValue))
     assert(decodedSchemaIdRow.schemaId === 42,
       "Schema ID not preserved in single value encoding")

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org