This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 3820ae0f86b9 [SPARK-50855][SS][CONNECT] Spark Connect Support for TransformWithState In Scala
3820ae0f86b9 is described below

commit 3820ae0f86b9935061e5eecade6e07f628a805ad
Author: jingz-db <jing.z...@databricks.com>
AuthorDate: Tue Mar 4 14:23:55 2025 +0900

    [SPARK-50855][SS][CONNECT] Spark Connect Support for TransformWithState In Scala

    ### What changes were proposed in this pull request?

    Add Spark Connect support for TransformWithState.

    ### Why are the changes needed?

    We need to implement our newly developed operator TransformWithState in Connect for feature parity.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    In `TransformWithStateStreamingSuite`.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #49488 from jingz-db/tws-connect-latest.

    Lead-authored-by: jingz-db <jing.z...@databricks.com>
    Co-authored-by: Jing Zhan <135738831+jingz...@users.noreply.github.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../sql/catalyst/plans/logical/TimeMode.scala      |  10 +-
 .../streaming/TransformWithStateConnectSuite.scala | 512 +++++++++++++++++++++
 .../spark/sql/connect/KeyValueGroupedDataset.scala |  76 ++-
 .../sql/connect/planner/SparkConnectPlanner.scala  |  70 ++-
 4 files changed, 658 insertions(+), 10 deletions(-)
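For a quick sense of what this enables from a client, here is a condensed, illustrative sketch of the first test in the new suite below (the remote address, input path and sink name are placeholders, not part of the patch; note that several of the overloads touched in KeyValueGroupedDataset.scala are still marked private[sql]):

```scala
// Condensed from TransformWithStateConnectSuite: count occurrences per key with a value
// state, running against a Spark Connect session. Paths and names are illustrative only.
import org.apache.spark.sql.connect.SparkSession
import org.apache.spark.sql.streaming.{OutputMode, TimeMode}

val spark = SparkSession.builder().remote("sc://localhost:15002").getOrCreate()
import spark.implicits._

val query = spark.readStream
  .schema("key STRING, value STRING")
  .option("maxFilesPerTrigger", 1)
  .parquet("/tmp/tws-input")
  .as[InputRowForConnectTest]           // case class defined in the suite below
  .groupByKey(_.key)
  .transformWithState[OutputRowForConnectTest](
    new BasicCountStatefulProcessor(),  // shipped to the server inside the UDF payload
    TimeMode.None(),
    OutputMode.Update())
  .writeStream
  .format("memory")
  .queryName("tws_sink")
  .start()

query.processAllAvailable()
spark.table("tws_sink").show()
```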
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeMode.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeMode.scala
index e870a83ec4ae..da454c1c4214 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeMode.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeMode.scala
@@ -21,7 +21,13 @@ import java.util.Locale
 import org.apache.spark.SparkIllegalArgumentException
 import org.apache.spark.sql.streaming.TimeMode
 
-/** TimeMode types used in transformWithState operator */
+/**
+ * TimeMode types used in transformWithState operator
+ *
+ * Note that we need to keep TimeMode.None() named as "NoTime" in case class here because a case
+ * class named "None" will introduce naming collision with scala native type None. See SPARK-51151
+ * for more info.
+ */
 case object NoTime extends TimeMode
 
 case object ProcessingTime extends TimeMode
@@ -31,7 +37,7 @@ case object EventTime extends TimeMode
 object TimeModes {
   def apply(timeMode: String): TimeMode = {
     timeMode.toLowerCase(Locale.ROOT) match {
-      case "none" =>
+      case "none" | "notime" =>
        NoTime
       case "processingtime" =>
         ProcessingTime
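The parser change above matters because the Connect client serializes the time mode as a plain string (`timeMode.toString` in the new `transformWithStateHelper` further down) and the server restores it with `TimeModes.apply`; `TimeMode.None()` prints as "NoTime" rather than "none". A minimal sketch of that round trip, assuming the catalyst objects are visible on the classpath:

```scala
import org.apache.spark.sql.catalyst.plans.logical.{NoTime, TimeModes}
import org.apache.spark.sql.streaming.TimeMode

// Client side sets .setTimeMode(timeMode.toString); server side calls TimeModes(...).
assert(TimeMode.None().toString == "NoTime")
assert(TimeModes("NoTime") == NoTime) // newly accepted spelling
assert(TimeModes("none") == NoTime)   // existing spelling still works
```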
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/TransformWithStateConnectSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/TransformWithStateConnectSuite.scala
new file mode 100644
index 000000000000..a50f60188a45
--- /dev/null
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/TransformWithStateConnectSuite.scala
@@ -0,0 +1,512 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.streaming
+
+import java.io.{BufferedWriter, File, FileWriter}
+import java.nio.file.Paths
+import java.sql.Timestamp
+
+import org.scalatest.concurrent.Eventually.eventually
+import org.scalatest.concurrent.Futures.timeout
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{DataFrame, Dataset, Encoders, Row}
+import org.apache.spark.sql.connect.SparkSession
+import org.apache.spark.sql.connect.test.{QueryTest, RemoteSparkSession}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.streaming.{ListState, MapState, OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeMode, TimerValues, TTLConfig, ValueState}
+import org.apache.spark.sql.types._
+
+case class InputRowForConnectTest(key: String, value: String)
+case class OutputRowForConnectTest(key: String, value: String)
+case class StateRowForConnectTest(count: Long)
+
+// A basic stateful processor which will return the occurrences of key
+class BasicCountStatefulProcessor
+    extends StatefulProcessor[String, InputRowForConnectTest, OutputRowForConnectTest]
+    with Logging {
+  @transient protected var _countState: ValueState[StateRowForConnectTest] = _
+
+  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
+    _countState = getHandle.getValueState[StateRowForConnectTest](
+      "countState",
+      Encoders.product[StateRowForConnectTest],
+      TTLConfig.NONE)
+  }
+
+  override def handleInputRows(
+      key: String,
+      inputRows: Iterator[InputRowForConnectTest],
+      timerValues: TimerValues): Iterator[OutputRowForConnectTest] = {
+    val count = inputRows.toSeq.length + {
+      if (_countState.exists()) {
+        _countState.get().count
+      } else {
+        0L
+      }
+    }
+    _countState.update(StateRowForConnectTest(count))
+    Iterator(OutputRowForConnectTest(key, count.toString))
+  }
+}
+
+// A stateful processor with initial state which will return the occurrences of key
+class TestInitialStatefulProcessor
+    extends StatefulProcessorWithInitialState[
+      String,
+      (String, String),
+      (String, String),
+      (String, String, String)]
+    with Logging {
+  @transient protected var _countState: ValueState[Long] = _
+
+  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
+    _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong, TTLConfig.NONE)
+  }
+
+  override def handleInputRows(
+      key: String,
+      inputRows: Iterator[(String, String)],
+      timerValues: TimerValues): Iterator[(String, String)] = {
+    val count = inputRows.toSeq.length + {
+      if (_countState.exists()) {
+        _countState.get()
+      } else {
+        0L
+      }
+    }
+    _countState.update(count)
+    Iterator((key, count.toString))
+  }
+
+  override def handleInitialState(
+      key: String,
+      initialState: (String, String, String),
+      timerValues: TimerValues): Unit = {
+    val count = 1 + {
+      if (_countState.exists()) {
+        _countState.get()
+      } else {
+        0L
+      }
+    }
+    _countState.update(count)
+  }
+}
+
+case class OutputEventTimeRow(key: String, outputTimestamp: Timestamp)
+
+// A stateful processor which will return timestamp of the first item from input rows
+class ChainingOfOpsStatefulProcessor
+    extends StatefulProcessor[String, (String, Timestamp), OutputEventTimeRow] {
+  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {}
+
+  override def handleInputRows(
+      key: String,
+      inputRows: Iterator[(String, Timestamp)],
+      timerValues: TimerValues): Iterator[OutputEventTimeRow] = {
+    val timestamp = inputRows.next()._2
+    Iterator(OutputEventTimeRow(key, timestamp))
+  }
+}
+
+// A basic stateful processor contains composite state variables and TTL
+class TTLTestStatefulProcessor
+    extends StatefulProcessor[String, (String, String), (String, String)] {
+  import java.time.Duration
+
+  @transient protected var countState: ValueState[Int] = _
+  @transient protected var ttlCountState: ValueState[Int] = _
+  @transient protected var ttlListState: ListState[Int] = _
+  @transient protected var ttlMapState: MapState[String, Int] = _
+
+  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
+    countState = getHandle.getValueState[Int]("countState", Encoders.scalaInt, TTLConfig.NONE)
+    ttlCountState = getHandle
+      .getValueState[Int]("ttlCountState", Encoders.scalaInt, TTLConfig(Duration.ofMillis(1000)))
+    ttlListState = getHandle
+      .getListState[Int]("ttlListState", Encoders.scalaInt, TTLConfig(Duration.ofMillis(1000)))
+    ttlMapState = getHandle.getMapState[String, Int](
+      "ttlMapState",
+      Encoders.STRING,
+      Encoders.scalaInt,
+      TTLConfig(Duration.ofMillis(1000)))
+  }
+
+  override def handleInputRows(
+      key: String,
+      inputRows: Iterator[(String, String)],
+      timerValues: TimerValues): Iterator[(String, String)] = {
+    val numOfInputRows = inputRows.toSeq.length
+    var count = numOfInputRows
+    var ttlCount = numOfInputRows
+    var ttlListStateCount = numOfInputRows
+    var ttlMapStateCount = numOfInputRows
+
+    if (countState.exists()) {
+      count += countState.get()
+    }
+    if (ttlCountState.exists()) {
+      ttlCount += ttlCountState.get()
+    }
+    if (ttlListState.exists()) {
+      for (value <- ttlListState.get()) {
+        ttlListStateCount += value
+      }
+    }
+    if (ttlMapState.exists()) {
+      ttlMapStateCount = ttlMapState.getValue(key)
+    }
+    countState.update(count)
+    if (key != "0") {
+      ttlCountState.update(ttlCount)
+      ttlListState.put(Array(ttlListStateCount, ttlListStateCount))
+      ttlMapState.updateValue(key, ttlMapStateCount)
+    }
+    val output = List(
+      (s"count-$key", count.toString),
+      (s"ttlCount-$key", ttlCount.toString),
+      (s"ttlListState-$key", ttlListStateCount.toString),
+      (s"ttlMapState-$key", ttlMapStateCount.toString))
+    output.iterator
+  }
+}
+
+class TransformWithStateConnectSuite extends QueryTest with RemoteSparkSession with Logging {
+  val testData: Seq[(String, String)] = Seq(("a", "1"), ("b", "1"), ("a", "2"))
+  val twsAdditionalSQLConf = Seq(
+    "spark.sql.streaming.stateStore.providerClass" ->
+      "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider",
+    "spark.sql.shuffle.partitions" -> "5",
+    "spark.sql.session.timeZone" -> "UTC",
+    "spark.sql.streaming.noDataMicroBatches.enabled" -> "false")
+
+  test("transformWithState - streaming with state variable, case class type") {
+    withSQLConf(twsAdditionalSQLConf: _*) {
+      val session: SparkSession = spark
+      import session.implicits._
+
+      spark.sql("DROP TABLE IF EXISTS my_sink")
+
+      withTempPath { dir =>
+        val path = dir.getCanonicalPath
+        testData
+          .toDS()
+          .toDF("key", "value")
+          .repartition(3)
+          .write
+          .parquet(path)
+
+        val testSchema =
+          StructType(Array(StructField("key", StringType), StructField("value", StringType)))
+
+        val q = spark.readStream
+          .schema(testSchema)
+          .option("maxFilesPerTrigger", 1)
+          .parquet(path)
+          .as[InputRowForConnectTest]
+          .groupByKey(x => x.key)
+          .transformWithState[OutputRowForConnectTest](
+            new BasicCountStatefulProcessor(),
+            TimeMode.None(),
+            OutputMode.Update())
+          .writeStream
+          .format("memory")
+          .queryName("my_sink")
+          .start()
+
+        try {
+          q.processAllAvailable()
+          eventually(timeout(30.seconds)) {
+            checkDatasetUnorderly(
+              spark.table("my_sink").toDF().as[(String, String)],
+              ("a", "1"),
+              ("a", "2"),
+              ("b", "1"))
+          }
+        } finally {
+          q.stop()
+          spark.sql("DROP TABLE IF EXISTS my_sink")
+        }
+      }
+    }
+  }
+
+  test("transformWithState - streaming with initial state") {
+    withSQLConf(twsAdditionalSQLConf: _*) {
+      val session: SparkSession = spark
+      import session.implicits._
+
+      spark.sql("DROP TABLE IF EXISTS my_sink")
+
+      withTempPath { dir =>
+        val path = dir.getCanonicalPath
+        testData
+          .toDS()
+          .toDF("key", "value")
+          .repartition(3)
+          .write
+          .parquet(path)
+
+        val testSchema =
+          StructType(Array(StructField("key", StringType), StructField("value", StringType)))
+
+        val initDf = Seq(("init_1", "40.0", "a"), ("init_2", "100.0", "b"))
+          .toDS()
+          .groupByKey(x => x._3)
+          .mapValues(x => x)
+
+        val q = spark.readStream
+          .schema(testSchema)
+          .option("maxFilesPerTrigger", 1)
+          .parquet(path)
+          .as[(String, String)]
+          .groupByKey(x => x._1)
+          .transformWithState(
+            new TestInitialStatefulProcessor(),
+            TimeMode.None(),
+            OutputMode.Update(),
+            initialState = initDf)
+          .writeStream
+          .format("memory")
+          .queryName("my_sink")
+          .start()
+
+        try {
+          q.processAllAvailable()
+          eventually(timeout(30.seconds)) {
+            checkDatasetUnorderly(
+              spark.table("my_sink").toDF().as[(String, String)],
+              ("a", "2"),
+              ("a", "3"),
+              ("b", "2"))
+          }
+        } finally {
+          q.stop()
+          spark.sql("DROP TABLE IF EXISTS my_sink")
+        }
+      }
+    }
+  }
+
+  test("transformWithState - streaming with chaining of operators") {
+    withSQLConf(twsAdditionalSQLConf: _*) {
+      val session: SparkSession = spark
+      import session.implicits._
+
+      def timestamp(num: Int): Timestamp = {
+        new Timestamp(num * 1000)
+      }
+
+      val checkResultFunc: (Dataset[Row], Long) => Unit = { (batchDF, batchId) =>
+        val realDf = batchDF.collect().toSet
+        if (batchId == 0) {
+          assert(realDf.isEmpty, s"BatchId: $batchId, RealDF: $realDf")
+        } else if (batchId == 1) {
+          // eviction watermark = 15 - 5 = 10 (max event time from batch 0),
+          // late event watermark = 0 (eviction event time from batch 0)
+          val expectedDF = Seq(Row(timestamp(10), 1L)).toSet
+          assert(
+            realDf == expectedDF,
+            s"BatchId: $batchId, expectedDf: $expectedDF, RealDF: $realDf")
+        } else if (batchId == 2) {
+          // eviction watermark = 25 - 5 = 20, late event watermark = 10;
+          // row with watermark=5<10 is dropped so it does not show up in the results;
+          // row with eventTime<=20 are finalized and emitted
+          val expectedDF = Seq(Row(timestamp(11), 1L), Row(timestamp(15), 1L)).toSet
+          assert(
+            realDf == expectedDF,
+            s"BatchId: $batchId, expectedDf: $expectedDF, RealDF: $realDf")
+        }
+      }
+
+      withTempPath { dir =>
+        val path = dir.getCanonicalPath
+        val curTime = System.currentTimeMillis
+        val file1 = prepareInputData(path + "/text-test3.csv", Seq("a", "b"), Seq(10, 15))
+        file1.setLastModified(curTime + 2L)
+        val file2 = prepareInputData(path + "/text-test4.csv", Seq("a", "c"), Seq(11, 25))
+        file2.setLastModified(curTime + 4L)
+        val file3 = prepareInputData(path + "/text-test1.csv", Seq("a"), Seq(5))
+        file3.setLastModified(curTime + 6L)
+
+        val q = buildTestDf(path, spark)
+          .select(col("key").as("key"), timestamp_seconds(col("value")).as("eventTime"))
+          .withWatermark("eventTime", "5 seconds")
+          .as[(String, Timestamp)]
+          .groupByKey(x => x._1)
+          .transformWithState[OutputEventTimeRow](
+            new ChainingOfOpsStatefulProcessor(),
+            "outputTimestamp",
+            OutputMode.Append())
+          .groupBy("outputTimestamp")
+          .count()
+          .writeStream
+          .foreachBatch(checkResultFunc)
+          .outputMode("Append")
+          .start()
+
+        q.processAllAvailable()
+        eventually(timeout(30.seconds)) {
+          q.stop()
+        }
+      }
+    }
+  }
+
+  test("transformWithState - streaming with TTL and composite state variables") {
+    withSQLConf(twsAdditionalSQLConf: _*) {
+      val session: SparkSession = spark
+      import session.implicits._
+
+      val checkResultFunc = (batchDF: Dataset[(String, String)], batchId: Long) => {
+        if (batchId == 0) {
+          val expectedDF = Set(
+            ("count-0", "1"),
+            ("ttlCount-0", "1"),
+            ("ttlListState-0", "1"),
+            ("ttlMapState-0", "1"),
+            ("count-1", "1"),
+            ("ttlCount-1", "1"),
+            ("ttlListState-1", "1"),
+            ("ttlMapState-1", "1"))
+
+          val realDf = batchDF.collect().toSet
+          assert(realDf == expectedDF)
+
+        } else if (batchId == 1) {
+          val expectedDF = Set(
+            ("count-0", "2"),
+            ("ttlCount-0", "1"),
+            ("ttlListState-0", "1"),
+            ("ttlMapState-0", "1"),
+            ("count-1", "2"),
+            ("ttlCount-1", "1"),
+            ("ttlListState-1", "1"),
+            ("ttlMapState-1", "1"))
+
+          val realDf = batchDF.collect().toSet
+          assert(realDf == expectedDF)
+        }
+
+        if (batchId == 0) {
+          // let ttl state expires
+          Thread.sleep(2000)
+        }
+      }
+
+      withTempPath { dir =>
+        val path = dir.getCanonicalPath
+        val curTime = System.currentTimeMillis
+        val file1 = prepareInputData(path + "/text-test3.csv", Seq("1", "0"), Seq(0, 0))
+        file1.setLastModified(curTime + 2L)
+        val file2 = prepareInputData(path + "/text-test4.csv", Seq("1", "0"), Seq(0, 0))
+        file2.setLastModified(curTime + 4L)
+
+        val q = buildTestDf(path, spark)
+          .as[(String, String)]
+          .groupByKey(x => x._1)
+          .transformWithState(
+            new TTLTestStatefulProcessor(),
+            TimeMode.ProcessingTime(),
+            OutputMode.Update())
+          .writeStream
+          .foreachBatch(checkResultFunc)
+          .outputMode("Update")
+          .start()
+        q.processAllAvailable()
+
+        eventually(timeout(30.seconds)) {
+          q.stop()
+        }
+      }
+    }
+  }
+
+  test("transformWithState - batch query") {
+    withSQLConf(twsAdditionalSQLConf: _*) {
+      val session: SparkSession = spark
+      import session.implicits._
+
+      spark.sql("DROP TABLE IF EXISTS my_sink")
+
+      withTempPath { dir =>
+        val path = dir.getCanonicalPath
+        testData
+          .toDS()
+          .toDF("key", "value")
+          .repartition(3)
+          .write
+          .parquet(path)
+
+        val testSchema =
+          StructType(Array(StructField("key", StringType), StructField("value", StringType)))
+
+        spark.read
+          .schema(testSchema)
+          .parquet(path)
+          .as[InputRowForConnectTest]
+          .groupByKey(x => x.key)
+          .transformWithState[OutputRowForConnectTest](
+            new BasicCountStatefulProcessor(),
+            TimeMode.None(),
+            OutputMode.Update())
+          .write
+          .saveAsTable("my_sink")
+
+        checkDatasetUnorderly(
+          spark.table("my_sink").toDF().as[(String, String)],
+          ("a", "2"),
+          ("b", "1"))
+      }
+    }
+  }
+
+  /* Utils functions for tests */
+  def prepareInputData(inputPath: String, col1: Seq[String], col2: Seq[Int]): File = {
+    // Ensure the parent directory exists
+    val file = Paths.get(inputPath).toFile
+    val parentDir = file.getParentFile
+    if (parentDir != null && !parentDir.exists()) {
+      parentDir.mkdirs()
+    }
+
+    val writer = new BufferedWriter(new FileWriter(inputPath))
+    try {
+      col1.zip(col2).foreach { case (e1, e2) =>
+        writer.write(s"$e1, $e2\n")
+      }
+    } finally {
+      writer.close()
+    }
+    file
+  }
+
+  def buildTestDf(inputPath: String, sparkSession: SparkSession): DataFrame = {
+    sparkSession.readStream
+      .format("csv")
+      .schema(
+        new StructType()
+          .add(StructField("key", StringType))
+          .add(StructField("value", StringType)))
+      .option("maxFilesPerTrigger", 1)
+      .load(inputPath)
+      .select(col("key").as("key"), col("value").cast("integer"))
+  }
+}
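The watermark expectations in the "chaining of operators" test above follow from the 5-second delay set by `withWatermark("eventTime", "5 seconds")`; a small worked sketch of that arithmetic (values taken from the test's own comments):

```scala
// Watermark progression across the three micro-batches of the chaining test.
val delaySeconds = 5
val batch0MaxEventTime = Seq(10, 15).max                        // text-test3.csv: (a, 10), (b, 15)
val batch1EvictionWatermark = batch0MaxEventTime - delaySeconds // = 10
val batch1MaxEventTime = Seq(11, 25).max                        // text-test4.csv: (a, 11), (c, 25)
val batch2EvictionWatermark = batch1MaxEventTime - delaySeconds // = 20
val batch2LateEventWatermark = batch1EvictionWatermark          // = 10, so the late (a, 5) row is dropped
```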
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/KeyValueGroupedDataset.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/KeyValueGroupedDataset.scala
index b15e8c28df74..090907a538c7 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/KeyValueGroupedDataset.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/KeyValueGroupedDataset.scala
@@ -141,7 +141,7 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends sql.KeyValueGroupedDa
       statefulProcessor: StatefulProcessor[K, V, U],
       timeMode: TimeMode,
       outputMode: OutputMode): Dataset[U] =
-    unsupported()
+    transformWithStateHelper(statefulProcessor, timeMode, outputMode)
 
   /** @inheritdoc */
   private[sql] def transformWithState[U: Encoder, S: Encoder](
@@ -149,20 +149,40 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends sql.KeyValueGroupedDa
       timeMode: TimeMode,
       outputMode: OutputMode,
       initialState: sql.KeyValueGroupedDataset[K, S]): Dataset[U] =
-    unsupported()
+    transformWithStateHelper(statefulProcessor, timeMode, outputMode, Some(initialState))
 
   /** @inheritdoc */
   override private[sql] def transformWithState[U: Encoder](
       statefulProcessor: StatefulProcessor[K, V, U],
       eventTimeColumnName: String,
-      outputMode: OutputMode): Dataset[U] = unsupported()
+      outputMode: OutputMode): Dataset[U] =
+    transformWithStateHelper(
+      statefulProcessor,
+      TimeMode.EventTime(),
+      outputMode,
+      eventTimeColumnName = eventTimeColumnName)
 
   /** @inheritdoc */
   override private[sql] def transformWithState[U: Encoder, S: Encoder](
       statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
       eventTimeColumnName: String,
       outputMode: OutputMode,
-      initialState: sql.KeyValueGroupedDataset[K, S]): Dataset[U] = unsupported()
+      initialState: sql.KeyValueGroupedDataset[K, S]): Dataset[U] =
+    transformWithStateHelper(
+      statefulProcessor,
+      TimeMode.EventTime(),
+      outputMode,
+      Some(initialState),
+      eventTimeColumnName)
+
+  // This is an interface, and it should not be used. The real implementation is in the
+  // inherited class.
+  protected[sql] def transformWithStateHelper[U: Encoder, S: Encoder](
+      statefulProcessor: StatefulProcessor[K, V, U],
+      timeMode: TimeMode,
+      outputMode: OutputMode,
+      initialState: Option[sql.KeyValueGroupedDataset[K, S]] = None,
+      eventTimeColumnName: String = ""): Dataset[U] = unsupported()
 
   // Overrides...
   /** @inheritdoc */
@@ -602,7 +622,6 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
     }
 
     val initialStateImpl = if (initialState.isDefined) {
-      assert(initialState.get.isInstanceOf[KeyValueGroupedDatasetImpl[K, S, _, _]])
       initialState.get.asInstanceOf[KeyValueGroupedDatasetImpl[K, S, _, _]]
     } else {
       null
@@ -632,6 +651,53 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
     }
   }
 
+  override protected[sql] def transformWithStateHelper[U: Encoder, S: Encoder](
+      statefulProcessor: StatefulProcessor[K, V, U],
+      timeMode: TimeMode,
+      outputMode: OutputMode,
+      initialState: Option[sql.KeyValueGroupedDataset[K, S]] = None,
+      eventTimeColumnName: String = ""): Dataset[U] = {
+    val outputEncoder = agnosticEncoderFor[U]
+    val stateEncoder = agnosticEncoderFor[S]
+    val inputEncoders: Seq[AgnosticEncoder[_]] = Seq(kEncoder, stateEncoder, ivEncoder)
+
+    // SparkUserDefinedFunction is creating a udfPacket where the input function are
+    // being java serialized into bytes; we pass in `statefulProcessor` as function so it can be
+    // serialized into bytes and deserialized back on connect server
+    val sparkUserDefinedFunc =
+      SparkUserDefinedFunction(statefulProcessor, inputEncoders, outputEncoder)
+    val funcProto = UdfToProtoUtils.toProto(sparkUserDefinedFunc)
+
+    val initialStateImpl = if (initialState.isDefined) {
+      initialState.get.asInstanceOf[KeyValueGroupedDatasetImpl[K, S, _, _]]
+    } else {
+      null
+    }
+
+    sparkSession.newDataset[U](outputEncoder) { builder =>
+      val twsBuilder = builder.getGroupMapBuilder
+      val twsInfoBuilder = proto.TransformWithStateInfo.newBuilder()
+      if (!eventTimeColumnName.isEmpty) {
+        twsInfoBuilder.setEventTimeColumnName(eventTimeColumnName)
+      }
+      twsBuilder
+        .setInput(plan.getRoot)
+        .addAllGroupingExpressions(groupingExprs)
+        .setFunc(funcProto)
+        .setOutputMode(outputMode.toString)
+        .setTransformWithStateInfo(
+          twsInfoBuilder
+            // we pass time mode as string here and deterministically restored on server
+            .setTimeMode(timeMode.toString)
+            .build())
+      if (initialStateImpl != null) {
+        twsBuilder
+          .addAllInitialGroupingExpressions(initialStateImpl.groupingExprs)
+          .setInitialInput(initialStateImpl.plan.getRoot)
+      }
+    }
+  }
+
   private def getUdf[U: Encoder](nf: AnyRef, outputEncoder: AgnosticEncoder[U])(
       inEncoders: AgnosticEncoder[_]*): proto.CommonInlineUserDefinedFunction = {
     val inputEncoders = kEncoder +: inEncoders // Apply keyAs changes by setting kEncoder
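One detail worth calling out before the server-side change: the client Java-serializes the `StatefulProcessor` instance into the UDF payload (see the `SparkUserDefinedFunction` comment above), and the planner deserializes it and casts it back, which is consistent with the suite's processors marking their state handles `@transient`. A standalone sketch of that round trip, assuming the processor is java-serializable as the Connect path requires (this only mimics the serialization aspect, not the actual wire format):

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// Serialize an object to bytes and read it back, the way the UDF payload carries the processor.
def roundTrip[T <: java.io.Serializable](value: T): T = {
  val buffer = new ByteArrayOutputStream()
  val out = new ObjectOutputStream(buffer)
  out.writeObject(value)
  out.close()
  val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
  in.readObject().asInstanceOf[T]
}

val restored = roundTrip(new BasicCountStatefulProcessor())
// Server-side analogue (see the planner change below):
//   unpackedUdf.function.asInstanceOf[StatefulProcessor[Any, Any, Any]]
```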
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 24fc1275d482..734eb394ca68 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -54,7 +54,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
 import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
 import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, Assignment, CoGroup, CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeleteAction, DeserializeToObject, Except, FlatMapGroupsWithState, InsertAction, InsertStarAction, Intersect, JoinWith, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, MergeAction, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, Unpivot, UnresolvedHint, UpdateAction, UpdateSt [...]
+import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, Assignment, CoGroup, CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeleteAction, DeserializeToObject, Except, FlatMapGroupsWithState, InsertAction, InsertStarAction, Intersect, JoinWith, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, MergeAction, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TimeModes, TransformWithState, TypedFilter, Union, Unpivot, Unresol [...]
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
 import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, TreePattern}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
@@ -81,7 +81,7 @@ import org.apache.spark.sql.execution.stat.StatFunctions
 import org.apache.spark.sql.execution.streaming.GroupStateImpl.groupStateTimeoutFromString
 import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper
 import org.apache.spark.sql.expressions.{Aggregator, ReduceAggregator, SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction}
-import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StreamingQuery, StreamingQueryListener, StreamingQueryProgress, Trigger}
+import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, StreamingQuery, StreamingQueryListener, StreamingQueryProgress, Trigger}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.{ArrowUtils, CaseInsensitiveStringMap}
 import org.apache.spark.storage.CacheId
@@ -684,7 +684,71 @@ class SparkConnectPlanner(
       rel.getGroupingExpressionsList,
       rel.getSortingExpressionsList)
 
-    if (rel.hasIsMapGroupsWithState) {
+    if (rel.hasTransformWithStateInfo) {
+      val hasInitialState = !rel.getInitialGroupingExpressionsList.isEmpty && rel.hasInitialInput
+
+      val twsInfo = rel.getTransformWithStateInfo
+      val keyDeserializer = udf.inputDeserializer(ds.groupingAttributes)
+      val outputAttr = udf.outputObjAttr
+
+      val timeMode = TimeModes(twsInfo.getTimeMode)
+      val outputMode = InternalOutputModes(rel.getOutputMode)
+
+      val twsNode = if (hasInitialState) {
+        val statefulProcessor = unpackedUdf.function
+          .asInstanceOf[StatefulProcessorWithInitialState[Any, Any, Any, Any]]
+        val initDs = UntypedKeyValueGroupedDataset(
+          rel.getInitialInput,
+          rel.getInitialGroupingExpressionsList,
+          rel.getSortingExpressionsList)
+        new TransformWithState(
+          keyDeserializer,
+          ds.valueDeserializer,
+          ds.groupingAttributes,
+          ds.dataAttributes,
+          statefulProcessor,
+          timeMode,
+          outputMode,
+          udf.inEnc.asInstanceOf[ExpressionEncoder[Any]],
+          outputAttr,
+          ds.analyzed,
+          hasInitialState,
+          initDs.groupingAttributes,
+          initDs.dataAttributes,
+          initDs.valueDeserializer,
+          initDs.analyzed)
+      } else {
+        val statefulProcessor =
+          unpackedUdf.function.asInstanceOf[StatefulProcessor[Any, Any, Any]]
+        new TransformWithState(
+          keyDeserializer,
+          ds.valueDeserializer,
+          ds.groupingAttributes,
+          ds.dataAttributes,
+          statefulProcessor,
+          timeMode,
+          outputMode,
+          udf.inEnc.asInstanceOf[ExpressionEncoder[Any]],
+          outputAttr,
+          ds.analyzed,
+          hasInitialState,
+          ds.groupingAttributes,
+          ds.dataAttributes,
+          keyDeserializer,
+          LocalRelation(ds.vEncoder.schema))
+      }
+      val serializedPlan = SerializeFromObject(udf.outputNamedExpression, twsNode)
+
+      if (twsInfo.hasEventTimeColumnName) {
+        val eventTimeWrappedPlan = UpdateEventTimeWatermarkColumn(
+          UnresolvedAttribute(twsInfo.getEventTimeColumnName),
+          None,
+          serializedPlan)
+        eventTimeWrappedPlan
+      } else {
+        serializedPlan
+      }
+    } else if (rel.hasIsMapGroupsWithState) {
       val hasInitialState = !rel.getInitialGroupingExpressionsList.isEmpty && rel.hasInitialInput
       val initialDs = if (hasInitialState) {
         UntypedKeyValueGroupedDataset(
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org