This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3820ae0f86b9 [SPARK-50855][SS][CONNECT] Spark Connect Support for TransformWithState In Scala
3820ae0f86b9 is described below

commit 3820ae0f86b9935061e5eecade6e07f628a805ad
Author: jingz-db <jing.z...@databricks.com>
AuthorDate: Tue Mar 4 14:23:55 2025 +0900

    [SPARK-50855][SS][CONNECT] Spark Connect Support for TransformWithState In Scala
    
    ### What changes were proposed in this pull request?
    
    Add Spark Connect support for TransformWithState.
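
    A minimal sketch of the client-side usage this enables. The `CountingProcessor`, the source `df`, and the console sink below are illustrative only (they assume `spark.implicits._` is in scope) and are not part of this patch:

    ```scala
    import org.apache.spark.sql.Encoders
    import org.apache.spark.sql.streaming.{OutputMode, StatefulProcessor, TimeMode, TimerValues, TTLConfig, ValueState}

    // Illustrative processor: counts occurrences per key in a value state.
    class CountingProcessor extends StatefulProcessor[String, (String, String), (String, Long)] {
      @transient private var countState: ValueState[Long] = _

      override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
        countState = getHandle.getValueState[Long]("count", Encoders.scalaLong, TTLConfig.NONE)
      }

      override def handleInputRows(
          key: String,
          inputRows: Iterator[(String, String)],
          timerValues: TimerValues): Iterator[(String, Long)] = {
        val updated = inputRows.size + (if (countState.exists()) countState.get() else 0L)
        countState.update(updated)
        Iterator((key, updated))
      }
    }

    // `df` stands for any streaming Dataset[(String, String)] created from a Connect session.
    df.groupByKey(_._1)
      .transformWithState(new CountingProcessor(), TimeMode.None(), OutputMode.Update())
      .writeStream
      .format("console")
      .start()
    ```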
    
    ### Why are the changes needed?
    
    We need to implement the newly developed TransformWithState operator in Spark Connect for feature parity.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    In the newly added `TransformWithStateConnectSuite`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #49488 from jingz-db/tws-connect-latest.
    
    Lead-authored-by: jingz-db <jing.z...@databricks.com>
    Co-authored-by: Jing Zhan <135738831+jingz...@users.noreply.github.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../sql/catalyst/plans/logical/TimeMode.scala      |  10 +-
 .../streaming/TransformWithStateConnectSuite.scala | 512 +++++++++++++++++++++
 .../spark/sql/connect/KeyValueGroupedDataset.scala |  76 ++-
 .../sql/connect/planner/SparkConnectPlanner.scala  |  70 ++-
 4 files changed, 658 insertions(+), 10 deletions(-)

diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeMode.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeMode.scala
index e870a83ec4ae..da454c1c4214 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeMode.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeMode.scala
@@ -21,7 +21,13 @@ import java.util.Locale
 import org.apache.spark.SparkIllegalArgumentException
 import org.apache.spark.sql.streaming.TimeMode
 
-/** TimeMode types used in transformWithState operator */
+/**
+ * TimeMode types used in transformWithState operator
+ *
+ * Note that we need to keep TimeMode.None() named as "NoTime" in case class here because a case
+ * class named "None" will introduce naming collision with scala native type None. See SPARK-51151
+ * for more info.
+ */
 case object NoTime extends TimeMode
 
 case object ProcessingTime extends TimeMode
@@ -31,7 +37,7 @@ case object EventTime extends TimeMode
 object TimeModes {
   def apply(timeMode: String): TimeMode = {
     timeMode.toLowerCase(Locale.ROOT) match {
-      case "none" =>
+      case "none" | "notime" =>
         NoTime
       case "processingtime" =>
         ProcessingTime
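
A small sketch (not part of the patch) of why the extra "notime" branch matters: the Connect client sends the time mode over the wire as `timeMode.toString`, and `TimeMode.None()` is the `NoTime` case object whose `toString` is "NoTime", so the server-side `TimeModes()` parser has to accept that spelling in addition to "none":

```scala
// Both spellings resolve to the same case object after this change.
assert(TimeModes("none") == NoTime)
assert(TimeModes(TimeMode.None().toString) == NoTime) // toString is "NoTime"
assert(TimeModes("processingTime") == ProcessingTime)
```
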
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/TransformWithStateConnectSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/TransformWithStateConnectSuite.scala
new file mode 100644
index 000000000000..a50f60188a45
--- /dev/null
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/TransformWithStateConnectSuite.scala
@@ -0,0 +1,512 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.streaming
+
+import java.io.{BufferedWriter, File, FileWriter}
+import java.nio.file.Paths
+import java.sql.Timestamp
+
+import org.scalatest.concurrent.Eventually.eventually
+import org.scalatest.concurrent.Futures.timeout
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{DataFrame, Dataset, Encoders, Row}
+import org.apache.spark.sql.connect.SparkSession
+import org.apache.spark.sql.connect.test.{QueryTest, RemoteSparkSession}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.streaming.{ListState, MapState, OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeMode, TimerValues, TTLConfig, ValueState}
+import org.apache.spark.sql.types._
+
+case class InputRowForConnectTest(key: String, value: String)
+case class OutputRowForConnectTest(key: String, value: String)
+case class StateRowForConnectTest(count: Long)
+
+// A basic stateful processor which will return the occurrences of key
+class BasicCountStatefulProcessor
+    extends StatefulProcessor[String, InputRowForConnectTest, OutputRowForConnectTest]
+    with Logging {
+  @transient protected var _countState: ValueState[StateRowForConnectTest] = _
+
+  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
+    _countState = getHandle.getValueState[StateRowForConnectTest](
+      "countState",
+      Encoders.product[StateRowForConnectTest],
+      TTLConfig.NONE)
+  }
+
+  override def handleInputRows(
+      key: String,
+      inputRows: Iterator[InputRowForConnectTest],
+      timerValues: TimerValues): Iterator[OutputRowForConnectTest] = {
+    val count = inputRows.toSeq.length + {
+      if (_countState.exists()) {
+        _countState.get().count
+      } else {
+        0L
+      }
+    }
+    _countState.update(StateRowForConnectTest(count))
+    Iterator(OutputRowForConnectTest(key, count.toString))
+  }
+}
+
+// A stateful processor with initial state which will return the occurrences of key
+class TestInitialStatefulProcessor
+    extends StatefulProcessorWithInitialState[
+      String,
+      (String, String),
+      (String, String),
+      (String, String, String)]
+    with Logging {
+  @transient protected var _countState: ValueState[Long] = _
+
+  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
+    _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong, TTLConfig.NONE)
+  }
+
+  override def handleInputRows(
+      key: String,
+      inputRows: Iterator[(String, String)],
+      timerValues: TimerValues): Iterator[(String, String)] = {
+    val count = inputRows.toSeq.length + {
+      if (_countState.exists()) {
+        _countState.get()
+      } else {
+        0L
+      }
+    }
+    _countState.update(count)
+    Iterator((key, count.toString))
+  }
+
+  override def handleInitialState(
+      key: String,
+      initialState: (String, String, String),
+      timerValues: TimerValues): Unit = {
+    val count = 1 + {
+      if (_countState.exists()) {
+        _countState.get()
+      } else {
+        0L
+      }
+    }
+    _countState.update(count)
+  }
+}
+
+case class OutputEventTimeRow(key: String, outputTimestamp: Timestamp)
+
+// A stateful processor which will return timestamp of the first item from input rows
+class ChainingOfOpsStatefulProcessor
+    extends StatefulProcessor[String, (String, Timestamp), OutputEventTimeRow] {
+  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {}
+
+  override def handleInputRows(
+      key: String,
+      inputRows: Iterator[(String, Timestamp)],
+      timerValues: TimerValues): Iterator[OutputEventTimeRow] = {
+    val timestamp = inputRows.next()._2
+    Iterator(OutputEventTimeRow(key, timestamp))
+  }
+}
+
+// A basic stateful processor contains composite state variables and TTL
+class TTLTestStatefulProcessor
+    extends StatefulProcessor[String, (String, String), (String, String)] {
+  import java.time.Duration
+
+  @transient protected var countState: ValueState[Int] = _
+  @transient protected var ttlCountState: ValueState[Int] = _
+  @transient protected var ttlListState: ListState[Int] = _
+  @transient protected var ttlMapState: MapState[String, Int] = _
+
+  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
+    countState = getHandle.getValueState[Int]("countState", Encoders.scalaInt, TTLConfig.NONE)
+    ttlCountState = getHandle
+      .getValueState[Int]("ttlCountState", Encoders.scalaInt, TTLConfig(Duration.ofMillis(1000)))
+    ttlListState = getHandle
+      .getListState[Int]("ttlListState", Encoders.scalaInt, TTLConfig(Duration.ofMillis(1000)))
+    ttlMapState = getHandle.getMapState[String, Int](
+      "ttlMapState",
+      Encoders.STRING,
+      Encoders.scalaInt,
+      TTLConfig(Duration.ofMillis(1000)))
+  }
+
+  override def handleInputRows(
+      key: String,
+      inputRows: Iterator[(String, String)],
+      timerValues: TimerValues): Iterator[(String, String)] = {
+    val numOfInputRows = inputRows.toSeq.length
+    var count = numOfInputRows
+    var ttlCount = numOfInputRows
+    var ttlListStateCount = numOfInputRows
+    var ttlMapStateCount = numOfInputRows
+
+    if (countState.exists()) {
+      count += countState.get()
+    }
+    if (ttlCountState.exists()) {
+      ttlCount += ttlCountState.get()
+    }
+    if (ttlListState.exists()) {
+      for (value <- ttlListState.get()) {
+        ttlListStateCount += value
+      }
+    }
+    if (ttlMapState.exists()) {
+      ttlMapStateCount = ttlMapState.getValue(key)
+    }
+    countState.update(count)
+    if (key != "0") {
+      ttlCountState.update(ttlCount)
+      ttlListState.put(Array(ttlListStateCount, ttlListStateCount))
+      ttlMapState.updateValue(key, ttlMapStateCount)
+    }
+    val output = List(
+      (s"count-$key", count.toString),
+      (s"ttlCount-$key", ttlCount.toString),
+      (s"ttlListState-$key", ttlListStateCount.toString),
+      (s"ttlMapState-$key", ttlMapStateCount.toString))
+    output.iterator
+  }
+}
+
+class TransformWithStateConnectSuite extends QueryTest with RemoteSparkSession with Logging {
+  val testData: Seq[(String, String)] = Seq(("a", "1"), ("b", "1"), ("a", "2"))
+  val twsAdditionalSQLConf = Seq(
+    "spark.sql.streaming.stateStore.providerClass" ->
+      "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider",
+    "spark.sql.shuffle.partitions" -> "5",
+    "spark.sql.session.timeZone" -> "UTC",
+    "spark.sql.streaming.noDataMicroBatches.enabled" -> "false")
+
+  test("transformWithState - streaming with state variable, case class type") {
+    withSQLConf(twsAdditionalSQLConf: _*) {
+      val session: SparkSession = spark
+      import session.implicits._
+
+      spark.sql("DROP TABLE IF EXISTS my_sink")
+
+      withTempPath { dir =>
+        val path = dir.getCanonicalPath
+        testData
+          .toDS()
+          .toDF("key", "value")
+          .repartition(3)
+          .write
+          .parquet(path)
+
+        val testSchema =
+          StructType(Array(StructField("key", StringType), StructField("value", StringType)))
+
+        val q = spark.readStream
+          .schema(testSchema)
+          .option("maxFilesPerTrigger", 1)
+          .parquet(path)
+          .as[InputRowForConnectTest]
+          .groupByKey(x => x.key)
+          .transformWithState[OutputRowForConnectTest](
+            new BasicCountStatefulProcessor(),
+            TimeMode.None(),
+            OutputMode.Update())
+          .writeStream
+          .format("memory")
+          .queryName("my_sink")
+          .start()
+
+        try {
+          q.processAllAvailable()
+          eventually(timeout(30.seconds)) {
+            checkDatasetUnorderly(
+              spark.table("my_sink").toDF().as[(String, String)],
+              ("a", "1"),
+              ("a", "2"),
+              ("b", "1"))
+          }
+        } finally {
+          q.stop()
+          spark.sql("DROP TABLE IF EXISTS my_sink")
+        }
+      }
+    }
+  }
+
+  test("transformWithState - streaming with initial state") {
+    withSQLConf(twsAdditionalSQLConf: _*) {
+      val session: SparkSession = spark
+      import session.implicits._
+
+      spark.sql("DROP TABLE IF EXISTS my_sink")
+
+      withTempPath { dir =>
+        val path = dir.getCanonicalPath
+        testData
+          .toDS()
+          .toDF("key", "value")
+          .repartition(3)
+          .write
+          .parquet(path)
+
+        val testSchema =
+          StructType(Array(StructField("key", StringType), StructField("value", StringType)))
+
+        val initDf = Seq(("init_1", "40.0", "a"), ("init_2", "100.0", "b"))
+          .toDS()
+          .groupByKey(x => x._3)
+          .mapValues(x => x)
+
+        val q = spark.readStream
+          .schema(testSchema)
+          .option("maxFilesPerTrigger", 1)
+          .parquet(path)
+          .as[(String, String)]
+          .groupByKey(x => x._1)
+          .transformWithState(
+            new TestInitialStatefulProcessor(),
+            TimeMode.None(),
+            OutputMode.Update(),
+            initialState = initDf)
+          .writeStream
+          .format("memory")
+          .queryName("my_sink")
+          .start()
+
+        try {
+          q.processAllAvailable()
+          eventually(timeout(30.seconds)) {
+            checkDatasetUnorderly(
+              spark.table("my_sink").toDF().as[(String, String)],
+              ("a", "2"),
+              ("a", "3"),
+              ("b", "2"))
+          }
+        } finally {
+          q.stop()
+          spark.sql("DROP TABLE IF EXISTS my_sink")
+        }
+      }
+    }
+  }
+
+  test("transformWithState - streaming with chaining of operators") {
+    withSQLConf(twsAdditionalSQLConf: _*) {
+      val session: SparkSession = spark
+      import session.implicits._
+
+      def timestamp(num: Int): Timestamp = {
+        new Timestamp(num * 1000)
+      }
+
+      val checkResultFunc: (Dataset[Row], Long) => Unit = { (batchDF, batchId) =>
+        val realDf = batchDF.collect().toSet
+        if (batchId == 0) {
+          assert(realDf.isEmpty, s"BatchId: $batchId, RealDF: $realDf")
+        } else if (batchId == 1) {
+          // eviction watermark = 15 - 5 = 10 (max event time from batch 0),
+          // late event watermark = 0 (eviction event time from batch 0)
+          val expectedDF = Seq(Row(timestamp(10), 1L)).toSet
+          assert(
+            realDf == expectedDF,
+            s"BatchId: $batchId, expectedDf: $expectedDF, RealDF: $realDf")
+        } else if (batchId == 2) {
+          // eviction watermark = 25 - 5 = 20, late event watermark = 10;
+          // row with watermark=5<10 is dropped so it does not show up in the results;
+          // row with eventTime<=20 are finalized and emitted
+          val expectedDF = Seq(Row(timestamp(11), 1L), Row(timestamp(15), 1L)).toSet
+          assert(
+            realDf == expectedDF,
+            s"BatchId: $batchId, expectedDf: $expectedDF, RealDF: $realDf")
+        }
+      }
+
+      withTempPath { dir =>
+        val path = dir.getCanonicalPath
+        val curTime = System.currentTimeMillis
+        val file1 = prepareInputData(path + "/text-test3.csv", Seq("a", "b"), Seq(10, 15))
+        file1.setLastModified(curTime + 2L)
+        val file2 = prepareInputData(path + "/text-test4.csv", Seq("a", "c"), Seq(11, 25))
+        file2.setLastModified(curTime + 4L)
+        val file3 = prepareInputData(path + "/text-test1.csv", Seq("a"), Seq(5))
+        file3.setLastModified(curTime + 6L)
+
+        val q = buildTestDf(path, spark)
+          .select(col("key").as("key"), 
timestamp_seconds(col("value")).as("eventTime"))
+          .withWatermark("eventTime", "5 seconds")
+          .as[(String, Timestamp)]
+          .groupByKey(x => x._1)
+          .transformWithState[OutputEventTimeRow](
+            new ChainingOfOpsStatefulProcessor(),
+            "outputTimestamp",
+            OutputMode.Append())
+          .groupBy("outputTimestamp")
+          .count()
+          .writeStream
+          .foreachBatch(checkResultFunc)
+          .outputMode("Append")
+          .start()
+
+        q.processAllAvailable()
+        eventually(timeout(30.seconds)) {
+          q.stop()
+        }
+      }
+    }
+  }
+
+  test("transformWithState - streaming with TTL and composite state 
variables") {
+    withSQLConf(twsAdditionalSQLConf: _*) {
+      val session: SparkSession = spark
+      import session.implicits._
+
+      val checkResultFunc = (batchDF: Dataset[(String, String)], batchId: Long) => {
+        if (batchId == 0) {
+          val expectedDF = Set(
+            ("count-0", "1"),
+            ("ttlCount-0", "1"),
+            ("ttlListState-0", "1"),
+            ("ttlMapState-0", "1"),
+            ("count-1", "1"),
+            ("ttlCount-1", "1"),
+            ("ttlListState-1", "1"),
+            ("ttlMapState-1", "1"))
+
+          val realDf = batchDF.collect().toSet
+          assert(realDf == expectedDF)
+
+        } else if (batchId == 1) {
+          val expectedDF = Set(
+            ("count-0", "2"),
+            ("ttlCount-0", "1"),
+            ("ttlListState-0", "1"),
+            ("ttlMapState-0", "1"),
+            ("count-1", "2"),
+            ("ttlCount-1", "1"),
+            ("ttlListState-1", "1"),
+            ("ttlMapState-1", "1"))
+
+          val realDf = batchDF.collect().toSet
+          assert(realDf == expectedDF)
+        }
+
+        if (batchId == 0) {
+          // let ttl state expires
+          Thread.sleep(2000)
+        }
+      }
+
+      withTempPath { dir =>
+        val path = dir.getCanonicalPath
+        val curTime = System.currentTimeMillis
+        val file1 = prepareInputData(path + "/text-test3.csv", Seq("1", "0"), Seq(0, 0))
+        file1.setLastModified(curTime + 2L)
+        val file2 = prepareInputData(path + "/text-test4.csv", Seq("1", "0"), Seq(0, 0))
+        file2.setLastModified(curTime + 4L)
+
+        val q = buildTestDf(path, spark)
+          .as[(String, String)]
+          .groupByKey(x => x._1)
+          .transformWithState(
+            new TTLTestStatefulProcessor(),
+            TimeMode.ProcessingTime(),
+            OutputMode.Update())
+          .writeStream
+          .foreachBatch(checkResultFunc)
+          .outputMode("Update")
+          .start()
+        q.processAllAvailable()
+
+        eventually(timeout(30.seconds)) {
+          q.stop()
+        }
+      }
+    }
+  }
+
+  test("transformWithState - batch query") {
+    withSQLConf(twsAdditionalSQLConf: _*) {
+      val session: SparkSession = spark
+      import session.implicits._
+
+      spark.sql("DROP TABLE IF EXISTS my_sink")
+
+      withTempPath { dir =>
+        val path = dir.getCanonicalPath
+        testData
+          .toDS()
+          .toDF("key", "value")
+          .repartition(3)
+          .write
+          .parquet(path)
+
+        val testSchema =
+          StructType(Array(StructField("key", StringType), StructField("value", StringType)))
+
+        spark.read
+          .schema(testSchema)
+          .parquet(path)
+          .as[InputRowForConnectTest]
+          .groupByKey(x => x.key)
+          .transformWithState[OutputRowForConnectTest](
+            new BasicCountStatefulProcessor(),
+            TimeMode.None(),
+            OutputMode.Update())
+          .write
+          .saveAsTable("my_sink")
+
+        checkDatasetUnorderly(
+          spark.table("my_sink").toDF().as[(String, String)],
+          ("a", "2"),
+          ("b", "1"))
+      }
+    }
+  }
+
+  /* Utils functions for tests */
+  def prepareInputData(inputPath: String, col1: Seq[String], col2: Seq[Int]): File = {
+    // Ensure the parent directory exists
+    val file = Paths.get(inputPath).toFile
+    val parentDir = file.getParentFile
+    if (parentDir != null && !parentDir.exists()) {
+      parentDir.mkdirs()
+    }
+
+    val writer = new BufferedWriter(new FileWriter(inputPath))
+    try {
+      col1.zip(col2).foreach { case (e1, e2) =>
+        writer.write(s"$e1, $e2\n")
+      }
+    } finally {
+      writer.close()
+    }
+    file
+  }
+
+  def buildTestDf(inputPath: String, sparkSession: SparkSession): DataFrame = {
+    sparkSession.readStream
+      .format("csv")
+      .schema(
+        new StructType()
+          .add(StructField("key", StringType))
+          .add(StructField("value", StringType)))
+      .option("maxFilesPerTrigger", 1)
+      .load(inputPath)
+      .select(col("key").as("key"), col("value").cast("integer"))
+  }
+}
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/KeyValueGroupedDataset.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/KeyValueGroupedDataset.scala
index b15e8c28df74..090907a538c7 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/KeyValueGroupedDataset.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/KeyValueGroupedDataset.scala
@@ -141,7 +141,7 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends sql.KeyValueGroupedDa
       statefulProcessor: StatefulProcessor[K, V, U],
       timeMode: TimeMode,
       outputMode: OutputMode): Dataset[U] =
-    unsupported()
+    transformWithStateHelper(statefulProcessor, timeMode, outputMode)
 
   /** @inheritdoc */
   private[sql] def transformWithState[U: Encoder, S: Encoder](
@@ -149,20 +149,40 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends sql.KeyValueGroupedDa
       timeMode: TimeMode,
       outputMode: OutputMode,
       initialState: sql.KeyValueGroupedDataset[K, S]): Dataset[U] =
-    unsupported()
+    transformWithStateHelper(statefulProcessor, timeMode, outputMode, Some(initialState))
 
   /** @inheritdoc */
   override private[sql] def transformWithState[U: Encoder](
       statefulProcessor: StatefulProcessor[K, V, U],
       eventTimeColumnName: String,
-      outputMode: OutputMode): Dataset[U] = unsupported()
+      outputMode: OutputMode): Dataset[U] =
+    transformWithStateHelper(
+      statefulProcessor,
+      TimeMode.EventTime(),
+      outputMode,
+      eventTimeColumnName = eventTimeColumnName)
 
   /** @inheritdoc */
   override private[sql] def transformWithState[U: Encoder, S: Encoder](
       statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
       eventTimeColumnName: String,
       outputMode: OutputMode,
-      initialState: sql.KeyValueGroupedDataset[K, S]): Dataset[U] = unsupported()
+      initialState: sql.KeyValueGroupedDataset[K, S]): Dataset[U] =
+    transformWithStateHelper(
+      statefulProcessor,
+      TimeMode.EventTime(),
+      outputMode,
+      Some(initialState),
+      eventTimeColumnName)
+
+  // This is an interface, and it should not be used. The real implementation is in the
+  // inherited class.
+  protected[sql] def transformWithStateHelper[U: Encoder, S: Encoder](
+      statefulProcessor: StatefulProcessor[K, V, U],
+      timeMode: TimeMode,
+      outputMode: OutputMode,
+      initialState: Option[sql.KeyValueGroupedDataset[K, S]] = None,
+      eventTimeColumnName: String = ""): Dataset[U] = unsupported()
 
   // Overrides...
   /** @inheritdoc */
@@ -602,7 +622,6 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
     }
 
     val initialStateImpl = if (initialState.isDefined) {
-      assert(initialState.get.isInstanceOf[KeyValueGroupedDatasetImpl[K, S, _, _]])
       initialState.get.asInstanceOf[KeyValueGroupedDatasetImpl[K, S, _, _]]
     } else {
       null
@@ -632,6 +651,53 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
     }
   }
 
+  override protected[sql] def transformWithStateHelper[U: Encoder, S: Encoder](
+      statefulProcessor: StatefulProcessor[K, V, U],
+      timeMode: TimeMode,
+      outputMode: OutputMode,
+      initialState: Option[sql.KeyValueGroupedDataset[K, S]] = None,
+      eventTimeColumnName: String = ""): Dataset[U] = {
+    val outputEncoder = agnosticEncoderFor[U]
+    val stateEncoder = agnosticEncoderFor[S]
+    val inputEncoders: Seq[AgnosticEncoder[_]] = Seq(kEncoder, stateEncoder, ivEncoder)
+
+    // SparkUserDefinedFunction is creating a udfPacket where the input function are
+    // being java serialized into bytes; we pass in `statefulProcessor` as function so it can be
+    // serialized into bytes and deserialized back on connect server
+    val sparkUserDefinedFunc =
+      SparkUserDefinedFunction(statefulProcessor, inputEncoders, outputEncoder)
+    val funcProto = UdfToProtoUtils.toProto(sparkUserDefinedFunc)
+
+    val initialStateImpl = if (initialState.isDefined) {
+      initialState.get.asInstanceOf[KeyValueGroupedDatasetImpl[K, S, _, _]]
+    } else {
+      null
+    }
+
+    sparkSession.newDataset[U](outputEncoder) { builder =>
+      val twsBuilder = builder.getGroupMapBuilder
+      val twsInfoBuilder = proto.TransformWithStateInfo.newBuilder()
+      if (!eventTimeColumnName.isEmpty) {
+        twsInfoBuilder.setEventTimeColumnName(eventTimeColumnName)
+      }
+      twsBuilder
+        .setInput(plan.getRoot)
+        .addAllGroupingExpressions(groupingExprs)
+        .setFunc(funcProto)
+        .setOutputMode(outputMode.toString)
+        .setTransformWithStateInfo(
+          twsInfoBuilder
+            // we pass time mode as string here and deterministically restored on server
+            .setTimeMode(timeMode.toString)
+            .build())
+      if (initialStateImpl != null) {
+        twsBuilder
+          .addAllInitialGroupingExpressions(initialStateImpl.groupingExprs)
+          .setInitialInput(initialStateImpl.plan.getRoot)
+      }
+    }
+  }
+
+    private def getUdf[U: Encoder](nf: AnyRef, outputEncoder: AgnosticEncoder[U])(
      inEncoders: AgnosticEncoder[_]*): proto.CommonInlineUserDefinedFunction = {
    val inputEncoders = kEncoder +: inEncoders // Apply keyAs changes by setting kEncoder
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 24fc1275d482..734eb394ca68 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -54,7 +54,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
 import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
 import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, Assignment, CoGroup, CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeleteAction, DeserializeToObject, Except, FlatMapGroupsWithState, InsertAction, InsertStarAction, Intersect, JoinWith, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, MergeAction, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, Unpivot, UnresolvedHint, UpdateAction, UpdateSt [...]
+import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, Assignment, CoGroup, CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeleteAction, DeserializeToObject, Except, FlatMapGroupsWithState, InsertAction, InsertStarAction, Intersect, JoinWith, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, MergeAction, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TimeModes, TransformWithState, TypedFilter, Union, Unpivot, Unresol [...]
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
 import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, TreePattern}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
@@ -81,7 +81,7 @@ import org.apache.spark.sql.execution.stat.StatFunctions
 import org.apache.spark.sql.execution.streaming.GroupStateImpl.groupStateTimeoutFromString
 import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper
 import org.apache.spark.sql.expressions.{Aggregator, ReduceAggregator, SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction}
-import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StreamingQuery, StreamingQueryListener, StreamingQueryProgress, Trigger}
+import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, StreamingQuery, StreamingQueryListener, StreamingQueryProgress, Trigger}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.{ArrowUtils, CaseInsensitiveStringMap}
 import org.apache.spark.storage.CacheId
@@ -684,7 +684,71 @@ class SparkConnectPlanner(
       rel.getGroupingExpressionsList,
       rel.getSortingExpressionsList)
 
-    if (rel.hasIsMapGroupsWithState) {
+    if (rel.hasTransformWithStateInfo) {
+      val hasInitialState = !rel.getInitialGroupingExpressionsList.isEmpty && rel.hasInitialInput
+
+      val twsInfo = rel.getTransformWithStateInfo
+      val keyDeserializer = udf.inputDeserializer(ds.groupingAttributes)
+      val outputAttr = udf.outputObjAttr
+
+      val timeMode = TimeModes(twsInfo.getTimeMode)
+      val outputMode = InternalOutputModes(rel.getOutputMode)
+
+      val twsNode = if (hasInitialState) {
+        val statefulProcessor = unpackedUdf.function
+          .asInstanceOf[StatefulProcessorWithInitialState[Any, Any, Any, Any]]
+        val initDs = UntypedKeyValueGroupedDataset(
+          rel.getInitialInput,
+          rel.getInitialGroupingExpressionsList,
+          rel.getSortingExpressionsList)
+        new TransformWithState(
+          keyDeserializer,
+          ds.valueDeserializer,
+          ds.groupingAttributes,
+          ds.dataAttributes,
+          statefulProcessor,
+          timeMode,
+          outputMode,
+          udf.inEnc.asInstanceOf[ExpressionEncoder[Any]],
+          outputAttr,
+          ds.analyzed,
+          hasInitialState,
+          initDs.groupingAttributes,
+          initDs.dataAttributes,
+          initDs.valueDeserializer,
+          initDs.analyzed)
+      } else {
+        val statefulProcessor =
+          unpackedUdf.function.asInstanceOf[StatefulProcessor[Any, Any, Any]]
+        new TransformWithState(
+          keyDeserializer,
+          ds.valueDeserializer,
+          ds.groupingAttributes,
+          ds.dataAttributes,
+          statefulProcessor,
+          timeMode,
+          outputMode,
+          udf.inEnc.asInstanceOf[ExpressionEncoder[Any]],
+          outputAttr,
+          ds.analyzed,
+          hasInitialState,
+          ds.groupingAttributes,
+          ds.dataAttributes,
+          keyDeserializer,
+          LocalRelation(ds.vEncoder.schema))
+      }
+      val serializedPlan = SerializeFromObject(udf.outputNamedExpression, twsNode)
+
+      if (twsInfo.hasEventTimeColumnName) {
+        val eventTimeWrappedPlan = UpdateEventTimeWatermarkColumn(
+          UnresolvedAttribute(twsInfo.getEventTimeColumnName),
+          None,
+          serializedPlan)
+        eventTimeWrappedPlan
+      } else {
+        serializedPlan
+      }
+    } else if (rel.hasIsMapGroupsWithState) {
       val hasInitialState = !rel.getInitialGroupingExpressionsList.isEmpty && rel.hasInitialInput
       val initialDs = if (hasInitialState) {
         UntypedKeyValueGroupedDataset(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
