This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5d1d44ff1201 [SPARK-49467][SS] Add support for state data source reader and list state
5d1d44ff1201 is described below

commit 5d1d44ff1201af562c87ed2898d67f04e3292683
Author: Anish Shrigondekar <[email protected]>
AuthorDate: Fri Sep 6 17:01:52 2024 +0900

    [SPARK-49467][SS] Add support for state data source reader and list state
    
    ### What changes were proposed in this pull request?
    Add support for state data source reader and list state
    
    ### Why are the changes needed?
    This change adds support for reading state that was written via a list state variable, used primarily within stateful processors running with the `transformWithState` operator (a writer-side sketch follows the example query below)
    
    ### Does this PR introduce _any_ user-facing change?
    Yes
    
    Users can read the state and `explode` the list entries using the following query:
    ```
            val stateReaderDf = spark.read
              .format("statestore")
              .option(StateSourceOptions.PATH, <checkpoint_location>)
              .option(StateSourceOptions.STATE_VAR_NAME, <state_var_name>)
              .load()
    
            val listStateDf = stateReaderDf
              .selectExpr(
                "key.value AS groupingKey",
                "list_value AS valueList",
                "partition_id")
              .select($"groupingKey",
                explode($"valueList").as("valueList"))
    ```
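
    For context, state like this is written by a stateful processor that registers a list state variable. A minimal writer-side sketch, mirroring the `SessionGroupsStatefulProcessor` added to the test suite in this patch (imports included for completeness):
    ```
    import org.apache.spark.sql.Encoders
    import org.apache.spark.sql.streaming.{ExpiredTimerInfo, ListState, OutputMode, StatefulProcessor, TimeMode, TimerValues}

    class SessionGroupsStatefulProcessor extends
      StatefulProcessor[String, (String, String), String] {
      @transient private var _groupsList: ListState[String] = _

      override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
        // "groupsList" is the name that the reader query passes via STATE_VAR_NAME
        _groupsList = getHandle.getListState("groupsList", Encoders.STRING)
      }

      override def handleInputRows(
          key: String,
          inputRows: Iterator[(String, String)],
          timerValues: TimerValues,
          expiredTimerInfo: ExpiredTimerInfo): Iterator[String] = {
        // append each incoming group for the grouping key to the list state
        inputRows.foreach { inputRow => _groupsList.appendValue(inputRow._2) }
        Iterator.empty
      }
    }
    ```
    The reader surfaces value state under a `single_value` column and list state under a `list_value` array column, alongside `key` and `partition_id`.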
    
    ### How was this patch tested?
    Added unit tests
    
    ```
    [info] Run completed in 1 minute, 3 seconds.
    [info] Total number of tests run: 8
    [info] Suites: completed 1, aborted 0
    [info] Tests: succeeded 8, failed 0, canceled 0, ignored 0, pending 0
    [info] All tests passed.
    ```
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #47978 from anishshri-db/task/SPARK-49467.
    
    Authored-by: Anish Shrigondekar <[email protected]>
    Signed-off-by: Jungtaek Lim <[email protected]>
---
 .../v2/state/StatePartitionReader.scala            |  46 ++++--
 .../datasources/v2/state/utils/SchemaUtil.scala    |  55 +++----
 .../StateDataSourceTransformWithStateSuite.scala   | 161 ++++++++++++++++++++-
 3 files changed, 216 insertions(+), 46 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
index 53576c335cb0..1af2ec174c66 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
@@ -19,12 +19,13 @@ package org.apache.spark.sql.execution.datasources.v2.state
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
+import org.apache.spark.sql.catalyst.util.GenericArrayData
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
 import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
 import org.apache.spark.sql.execution.streaming.{StateVariableType, TransformWithStateVariableInfo}
 import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.execution.streaming.state.RecordType.{getRecordTypeAsString, RecordType}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{NullType, StructField, StructType}
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.{NextIterator, SerializableConfiguration}
 
@@ -68,10 +69,20 @@ abstract class StatePartitionReaderBase(
     stateVariableInfoOpt: Option[TransformWithStateVariableInfo],
     stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema])
   extends PartitionReader[InternalRow] with Logging {
+  // Used primarily as a placeholder for the value schema in the context of
+  // state variables used within the transformWithState operator.
+  private val schemaForValueRow: StructType =
+    StructType(Array(StructField("__dummy__", NullType)))
+
   protected val keySchema = SchemaUtil.getSchemaAsDataType(
     schema, "key").asInstanceOf[StructType]
-  protected val valueSchema = SchemaUtil.getSchemaAsDataType(
-    schema, "value").asInstanceOf[StructType]
+
+  protected val valueSchema = if (stateVariableInfoOpt.isDefined) {
+    schemaForValueRow
+  } else {
+    SchemaUtil.getSchemaAsDataType(
+      schema, "value").asInstanceOf[StructType]
+  }
 
   protected lazy val provider: StateStoreProvider = {
     val stateStoreId = StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString,
@@ -84,10 +95,17 @@ abstract class StatePartitionReaderBase(
       false
     }
 
+    val useMultipleValuesPerKey = if (stateVariableInfoOpt.isDefined &&
+      stateVariableInfoOpt.get.stateVariableType == StateVariableType.ListState) {
+      true
+    } else {
+      false
+    }
+
     val provider = StateStoreProvider.createAndInit(
       stateStoreProviderId, keySchema, valueSchema, keyStateEncoderSpec,
       useColumnFamilies = useColFamilies, storeConf, hadoopConf.value,
-      useMultipleValuesPerKey = false)
+      useMultipleValuesPerKey = useMultipleValuesPerKey)
 
     if (useColFamilies) {
       val store = provider.getStore(partition.sourceOptions.batchId + 1)
@@ -99,7 +117,7 @@ abstract class StatePartitionReaderBase(
         stateStoreColFamilySchema.keySchema,
         stateStoreColFamilySchema.valueSchema,
         stateStoreColFamilySchema.keyStateEncoderSpec.get,
-        useMultipleValuesPerKey = false)
+        useMultipleValuesPerKey = useMultipleValuesPerKey)
     }
     provider
   }
@@ -166,16 +184,22 @@ class StatePartitionReader(
         stateVariableInfoOpt match {
           case Some(stateVarInfo) =>
             val stateVarType = stateVarInfo.stateVariableType
-            val hasTTLEnabled = stateVarInfo.ttlEnabled
 
             stateVarType match {
               case StateVariableType.ValueState =>
-                if (hasTTLEnabled) {
-                  SchemaUtil.unifyStateRowPairWithTTL((pair.key, pair.value), valueSchema,
-                    partition.partition)
-                } else {
-                  SchemaUtil.unifyStateRowPair((pair.key, pair.value), partition.partition)
+                SchemaUtil.unifyStateRowPair((pair.key, pair.value), partition.partition)
+
+              case StateVariableType.ListState =>
+                val key = pair.key
+                val result = store.valuesIterator(key, stateVarName)
+                var unsafeRowArr: Seq[UnsafeRow] = Seq.empty
+                result.foreach { entry =>
+                  unsafeRowArr = unsafeRowArr :+ entry.copy()
                 }
+                // convert the list of values to array type
+                val arrData = new GenericArrayData(unsafeRowArr.toArray)
+                SchemaUtil.unifyStateRowPairWithMultipleValues((pair.key, arrData),
+                  partition.partition)
 
               case _ =>
                 throw new IllegalStateException(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala
index 9dd357530ec4..47bf9250000a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala
@@ -19,10 +19,11 @@ package org.apache.spark.sql.execution.datasources.v2.state.utils
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
+import org.apache.spark.sql.catalyst.util.GenericArrayData
 import org.apache.spark.sql.execution.datasources.v2.state.{StateDataSourceErrors, StateSourceOptions}
 import org.apache.spark.sql.execution.streaming.{StateVariableType, TransformWithStateVariableInfo}
 import org.apache.spark.sql.execution.streaming.state.StateStoreColFamilySchema
-import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType, StructType}
+import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType, LongType, StringType, StructType}
 import org.apache.spark.util.ArrayImplicits._
 
 object SchemaUtil {
@@ -70,15 +71,13 @@ object SchemaUtil {
     row
   }
 
-  def unifyStateRowPairWithTTL(
-      pair: (UnsafeRow, UnsafeRow),
-      valueSchema: StructType,
+  def unifyStateRowPairWithMultipleValues(
+      pair: (UnsafeRow, GenericArrayData),
       partition: Int): InternalRow = {
-    val row = new GenericInternalRow(4)
+    val row = new GenericInternalRow(3)
     row.update(0, pair._1)
-    row.update(1, pair._2.get(0, valueSchema))
-    row.update(2, pair._2.get(1, LongType))
-    row.update(3, partition)
+    row.update(1, pair._2)
+    row.update(2, partition)
     row
   }
 
@@ -91,23 +90,22 @@ object SchemaUtil {
       "change_type" -> classOf[StringType],
       "key" -> classOf[StructType],
       "value" -> classOf[StructType],
-      "partition_id" -> classOf[IntegerType],
-      "expiration_timestamp" -> classOf[LongType])
+      "single_value" -> classOf[StructType],
+      "list_value" -> classOf[ArrayType],
+      "partition_id" -> classOf[IntegerType])
 
     val expectedFieldNames = if (sourceOptions.readChangeFeed) {
       Seq("batch_id", "change_type", "key", "value", "partition_id")
     } else if (transformWithStateVariableInfoOpt.isDefined) {
       val stateVarInfo = transformWithStateVariableInfoOpt.get
-      val hasTTLEnabled = stateVarInfo.ttlEnabled
       val stateVarType = stateVarInfo.stateVariableType
 
       stateVarType match {
         case StateVariableType.ValueState =>
-          if (hasTTLEnabled) {
-            Seq("key", "value", "expiration_timestamp", "partition_id")
-          } else {
-            Seq("key", "value", "partition_id")
-          }
+          Seq("key", "single_value", "partition_id")
+
+        case StateVariableType.ListState =>
+          Seq("key", "list_value", "partition_id")
 
         case _ =>
           throw StateDataSourceErrors
@@ -131,24 +129,19 @@ object SchemaUtil {
       stateVarInfo: TransformWithStateVariableInfo,
       stateStoreColFamilySchema: StateStoreColFamilySchema): StructType = {
     val stateVarType = stateVarInfo.stateVariableType
-    val hasTTLEnabled = stateVarInfo.ttlEnabled
 
     stateVarType match {
       case StateVariableType.ValueState =>
-        if (hasTTLEnabled) {
-          val ttlValueSchema = SchemaUtil.getSchemaAsDataType(
-            stateStoreColFamilySchema.valueSchema, "value").asInstanceOf[StructType]
-          new StructType()
-            .add("key", stateStoreColFamilySchema.keySchema)
-            .add("value", ttlValueSchema)
-            .add("expiration_timestamp", LongType)
-            .add("partition_id", IntegerType)
-        } else {
-          new StructType()
-            .add("key", stateStoreColFamilySchema.keySchema)
-            .add("value", stateStoreColFamilySchema.valueSchema)
-            .add("partition_id", IntegerType)
-        }
+        new StructType()
+          .add("key", stateStoreColFamilySchema.keySchema)
+          .add("single_value", stateStoreColFamilySchema.valueSchema)
+          .add("partition_id", IntegerType)
+
+      case StateVariableType.ListState =>
+        new StructType()
+          .add("key", stateStoreColFamilySchema.keySchema)
+          .add("list_value", ArrayType(stateStoreColFamilySchema.valueSchema))
+          .add("partition_id", IntegerType)
 
       case _ =>
         throw StateDataSourceErrors.internalError(s"Unsupported state variable type $stateVarType")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
index ccd4e005756a..1c06e4f97f2b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
@@ -21,8 +21,9 @@ import java.time.Duration
 import org.apache.spark.sql.{Encoders, Row}
 import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider, TestClass}
+import org.apache.spark.sql.functions.explode
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.streaming.{ExpiredTimerInfo, OutputMode, RunningCountStatefulProcessor, StatefulProcessor, StateStoreMetricsTest, TimeMode, TimerValues, TransformWithStateSuiteUtils, TTLConfig, ValueState}
+import org.apache.spark.sql.streaming.{ExpiredTimerInfo, ListState, OutputMode, RunningCountStatefulProcessor, StatefulProcessor, StateStoreMetricsTest, TimeMode, TimerValues, TransformWithStateSuiteUtils, TTLConfig, ValueState}
 
 /** Stateful processor of single value state var with non-primitive type */
 class StatefulProcessorWithSingleValueVar extends RunningCountStatefulProcessor {
@@ -73,6 +74,52 @@ class StatefulProcessorWithTTL
   }
 }
 
+/** Stateful processor tracking groups belonging to sessions with/without TTL */
+class SessionGroupsStatefulProcessor extends
+  StatefulProcessor[String, (String, String), String] {
+  @transient private var _groupsList: ListState[String] = _
+
+  override def init(
+      outputMode: OutputMode,
+      timeMode: TimeMode): Unit = {
+    _groupsList = getHandle.getListState("groupsList", Encoders.STRING)
+  }
+
+  override def handleInputRows(
+      key: String,
+      inputRows: Iterator[(String, String)],
+      timerValues: TimerValues,
+      expiredTimerInfo: ExpiredTimerInfo): Iterator[String] = {
+    inputRows.foreach { inputRow =>
+      _groupsList.appendValue(inputRow._2)
+    }
+    Iterator.empty
+  }
+}
+
+class SessionGroupsStatefulProcessorWithTTL extends
+  StatefulProcessor[String, (String, String), String] {
+  @transient private var _groupsListWithTTL: ListState[String] = _
+
+  override def init(
+      outputMode: OutputMode,
+      timeMode: TimeMode): Unit = {
+    _groupsListWithTTL = getHandle.getListState("groupsListWithTTL", Encoders.STRING,
+      TTLConfig(Duration.ofMillis(30000)))
+  }
+
+  override def handleInputRows(
+      key: String,
+      inputRows: Iterator[(String, String)],
+      timerValues: TimerValues,
+      expiredTimerInfo: ExpiredTimerInfo): Iterator[String] = {
+    inputRows.foreach { inputRow =>
+      _groupsListWithTTL.appendValue(inputRow._2)
+    }
+    Iterator.empty
+  }
+}
+
 /**
  * Test suite to verify integration of state data source reader with the transformWithState operator
  */
@@ -111,7 +158,7 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest
 
         val resultDf = stateReaderDf.selectExpr(
           "key.value AS groupingKey",
-          "value.id AS valueId", "value.name AS valueName",
+          "single_value.id AS valueId", "single_value.name AS valueName",
           "partition_id")
 
         checkAnswer(resultDf,
@@ -174,7 +221,7 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest
           .load()
 
         val resultDf = stateReaderDf.selectExpr(
-          "key.value", "value.value", "expiration_timestamp", "partition_id")
+          "key.value", "single_value.value", "single_value.ttlExpirationMs", 
"partition_id")
 
         var count = 0L
         resultDf.collect().foreach { row =>
@@ -187,7 +234,7 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest
 
         val answerDf = stateReaderDf.selectExpr(
           "key.value AS groupingKey",
-          "value.value AS valueId", "partition_id")
+          "single_value.value.value AS valueId", "partition_id")
         checkAnswer(answerDf,
           Seq(Row("a", 1L, 0), Row("b", 1L, 1)))
 
@@ -217,4 +264,110 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest
       }
     }
   }
+
+  test("state data source integration - list state") {
+    withTempDir { tempDir =>
+      withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+        classOf[RocksDBStateStoreProvider].getName) {
+
+        val inputData = MemoryStream[(String, String)]
+        val result = inputData.toDS()
+          .groupByKey(x => x._1)
+          .transformWithState(new SessionGroupsStatefulProcessor(),
+            TimeMode.None(),
+            OutputMode.Update())
+
+        testStream(result, OutputMode.Update())(
+          StartStream(checkpointLocation = tempDir.getAbsolutePath),
+          AddData(inputData, ("session1", "group2")),
+          AddData(inputData, ("session1", "group1")),
+          AddData(inputData, ("session2", "group1")),
+          CheckNewAnswer(),
+          AddData(inputData, ("session3", "group7")),
+          AddData(inputData, ("session1", "group4")),
+          CheckNewAnswer(),
+          StopStream
+        )
+
+        val stateReaderDf = spark.read
+          .format("statestore")
+          .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+          .option(StateSourceOptions.STATE_VAR_NAME, "groupsList")
+          .load()
+
+        val listStateDf = stateReaderDf
+          .selectExpr(
+      "key.value AS groupingKey",
+            "list_value.value AS valueList",
+            "partition_id")
+          .select($"groupingKey",
+            explode($"valueList"))
+
+        checkAnswer(listStateDf,
+          Seq(Row("session1", "group1"), Row("session1", "group2"), 
Row("session1", "group4"),
+            Row("session2", "group1"), Row("session3", "group7")))
+      }
+    }
+  }
+
+  test("state data source integration - list state and TTL") {
+    withTempDir { tempDir =>
+      withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+        classOf[RocksDBStateStoreProvider].getName,
+        SQLConf.SHUFFLE_PARTITIONS.key ->
+          TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {
+        val inputData = MemoryStream[(String, String)]
+        val result = inputData.toDS()
+          .groupByKey(x => x._1)
+          .transformWithState(new SessionGroupsStatefulProcessorWithTTL(),
+            TimeMode.ProcessingTime(),
+            OutputMode.Update())
+
+        testStream(result, OutputMode.Update())(
+          StartStream(checkpointLocation = tempDir.getAbsolutePath),
+          AddData(inputData, ("session1", "group2")),
+          AddData(inputData, ("session1", "group1")),
+          AddData(inputData, ("session2", "group1")),
+          AddData(inputData, ("session3", "group7")),
+          AddData(inputData, ("session1", "group4")),
+          Execute { _ =>
+            // wait for the batch to run since we are using processing time
+            Thread.sleep(5000)
+          },
+          StopStream
+        )
+
+        val stateReaderDf = spark.read
+          .format("statestore")
+          .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+          .option(StateSourceOptions.STATE_VAR_NAME, "groupsListWithTTL")
+          .load()
+
+        val listStateDf = stateReaderDf
+          .selectExpr(
+      "key.value AS groupingKey",
+            "list_value AS valueList",
+            "partition_id")
+          .select($"groupingKey",
+            explode($"valueList").as("valueList"))
+
+        val resultDf = listStateDf.selectExpr("valueList.ttlExpirationMs")
+        var count = 0L
+        resultDf.collect().foreach { row =>
+          count = count + 1
+          assert(row.getLong(0) > 0)
+        }
+
+        // verify that 5 state rows are present
+        assert(count === 5)
+
+        val valuesDf = listStateDf.selectExpr("groupingKey",
+          "valueList.value.value AS groupId")
+
+        checkAnswer(valuesDf,
+          Seq(Row("session1", "group1"), Row("session1", "group2"), 
Row("session1", "group4"),
+          Row("session2", "group1"), Row("session3", "group7")))
+      }
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
