This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 51b011f64273 [SPARK-50661][CONNECT][SS] Fix Spark Connect Scala foreachBatch impl. to support Dataset[T].
51b011f64273 is described below
commit 51b011f642730ebcffd6f63d95a632f939e9d474
Author: Haiyang Sun <[email protected]>
AuthorDate: Sat Dec 28 16:20:59 2024 +0900
[SPARK-50661][CONNECT][SS] Fix Spark Connect Scala foreachBatch impl. to support Dataset[T].
### What changes were proposed in this pull request?
This PR fixes the incorrect implementation of Scala streaming foreachBatch when the input dataset is not a DataFrame (but a Dataset[T]) in Spark Connect mode.
**Note** that this only affects `Scala`.
In `DataStreamWriter`:
- Serialize the foreachBatch function together with the dataset's encoder.
- Reuse `ForeachWriterPacket` for foreachBatch, since both are sink operations that only require a function/writer object and the encoder of the input. Optionally, we could rename `ForeachWriterPacket` to something more general that covers both cases.
In `SparkConnectPlanner` / `StreamingForeachBatchHelper`:
- Use the encoder passed from the client to recover the Dataset[T] object and properly call the foreachBatch function (see the sketch after this list).
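For orientation, the round trip looks roughly like the following condensed sketch. The names come from the diff below; `function`, `ds`, `payloadBytes`, `batchDf`, and `batchId` stand in for the values available at each point, so this is illustrative rather than a standalone program:
```
// Client side (DataStreamWriter): bundle the user function with the dataset's encoder.
val serializedFn =
  SparkSerDeUtils.serialize(ForeachWriterPacket(function, ds.agnosticEncoder))

// Server side (StreamingForeachBatchHelper): unpack the payload and rebuild the
// typed Dataset before invoking the user function on each micro-batch.
val packet =
  Utils.deserialize[ForeachWriterPacket](payloadBytes, Utils.getContextOrSparkClassLoader)
val fn = packet.foreachWriter.asInstanceOf[(Dataset[Any], Long) => Unit]
val encoder = packet.datasetEncoder.asInstanceOf[AgnosticEncoder[Any]]
val typedDs =
  if (AgnosticEncoders.UnboundRowEncoder == encoder) {
    // The input was a plain DataFrame (Dataset[Row]); pass it through unchanged.
    batchDf.asInstanceOf[Dataset[Any]]
  } else {
    // Recover the Dataset[T] from the micro-batch DataFrame using the client-provided encoder.
    Dataset.apply(batchDf.sparkSession, batchDf.logicalPlan)(encoder)
  }
fn(typedDs, batchId)
```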
### Why are the changes needed?
Without the fix, Scala foreachBatch will fail or give wrong results when the input dataset is not a DataFrame.
Below is a simple reproduction:
```
import org.apache.spark.sql._

spark.range(10).write.format("parquet").mode("overwrite").save("/tmp/test")

val q = spark.readStream
  .format("parquet")
  .schema("id LONG")
  .load("/tmp/test")
  .as[java.lang.Long]
  .writeStream
  .foreachBatch((ds: Dataset[java.lang.Long], batchId: Long) =>
    println(ds.collect().map(_.asInstanceOf[Long]).sum))
  .start()

Thread.sleep(1000)
q.stop()
```
The code above should print 45 (the sum of 0 through 9) from the foreachBatch function. Without the fix, the code fails because the foreachBatch function is called with a DataFrame object instead of a Dataset[java.lang.Long].
### Does this PR introduce _any_ user-facing change?
Yes. This PR changes the Spark Connect client's behavior around the foreachBatch API: the encoder of the input dataset is now serialized together with the foreachBatch function.
### How was this patch tested?
1. Ran an end-to-end test with spark-shell (with a Spark Connect server and the client running in Connect mode).
2. Added new and updated unit tests that would have failed without the fix.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #49323 from haiyangsun-db/SPARK-50661.
Authored-by: Haiyang Sun <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../spark/sql/streaming/DataStreamWriter.scala | 5 +-
.../sql/streaming/ClientStreamingQuerySuite.scala | 93 ++++++++++++++++++----
.../sql/connect/planner/SparkConnectPlanner.scala | 5 +-
.../planner/StreamingForeachBatchHelper.scala | 27 ++++++-
4 files changed, 108 insertions(+), 22 deletions(-)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index 9fcc31e56268..b2c4fcf64e70 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -135,7 +135,10 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends api.DataSt
/** @inheritdoc */
@Evolving
def foreachBatch(function: (Dataset[T], Long) => Unit): this.type = {
- val serializedFn = SparkSerDeUtils.serialize(function)
+ // SPARK-50661: the client should send the encoder for the input dataset together with the
+ // function to the server.
+ val serializedFn =
+ SparkSerDeUtils.serialize(ForeachWriterPacket(function, ds.agnosticEncoder))
sinkBuilder.getForeachBatchBuilder.getScalaFunctionBuilder
.setPayload(ByteString.copyFrom(serializedFn))
.setOutputType(DataTypeProtoConverter.toConnectProtoType(NullType)) // Unused.
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
index b1a7d81916e9..199a1507a3b1 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
@@ -28,9 +28,8 @@ import org.scalatest.concurrent.Futures.timeout
import org.scalatest.time.SpanSugar._
import org.apache.spark.SparkException
-import org.apache.spark.api.java.function.VoidFunction2
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{DataFrame, ForeachWriter, Row, SparkSession}
+import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row, SparkSession}
import org.apache.spark.sql.functions.{col, lit, udf, window}
import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryIdleEvent, QueryProgressEvent, QueryStartedEvent, QueryTerminatedEvent}
import org.apache.spark.sql.test.{IntegrationTestUtils, QueryTest, RemoteSparkSession}
@@ -567,7 +566,7 @@ class ClientStreamingQuerySuite extends QueryTest with RemoteSparkSession with L
}
}
- test("foreachBatch") {
+ test("foreachBatch with DataFrame") {
// Starts a streaming query with a foreachBatch function, which writes batchId and row count
// to a temp view. The test verifies that the view is populated with data.
@@ -581,7 +580,12 @@ class ClientStreamingQuerySuite extends QueryTest with RemoteSparkSession with L
.option("numPartitions", "1")
.load()
.writeStream
- .foreachBatch(new ForeachBatchFn(viewName))
+ .foreachBatch((df: DataFrame, batchId: Long) => {
+ val count = df.collect().map(row => row.getLong(1)).sum
+ df.sparkSession
+ .createDataFrame(Seq((batchId, count)))
+ .createOrReplaceGlobalTempView(viewName)
+ })
.start()
eventually(timeout(30.seconds)) { // Wait for first progress.
@@ -596,6 +600,7 @@ class ClientStreamingQuerySuite extends QueryTest with RemoteSparkSession with L
.collect()
.toSeq
assert(rows.size > 0)
+ assert(rows.map(_.getLong(1)).sum > 0)
logInfo(s"Rows in $tableName: $rows")
}
@@ -603,6 +608,75 @@ class ClientStreamingQuerySuite extends QueryTest with RemoteSparkSession with L
}
}
+ test("foreachBatch with Dataset[java.lang.Long]") {
+ val viewName = "test_view"
+ val tableName = s"global_temp.$viewName"
+
+ withTable(tableName) {
+ val session = spark
+ import session.implicits._
+ val q = spark.readStream
+ .format("rate")
+ .option("rowsPerSecond", "10")
+ .option("numPartitions", "1")
+ .load()
+ .select($"value")
+ .as[java.lang.Long]
+ .writeStream
+ .foreachBatch((ds: Dataset[java.lang.Long], batchId: Long) => {
+ val count = ds.collect().map(v => v.asInstanceOf[Long]).sum
+ ds.sparkSession
+ .createDataFrame(Seq((batchId, count)))
+ .createOrReplaceGlobalTempView(viewName)
+ })
+ .start()
+
+ eventually(timeout(30.seconds)) { // Wait for first progress.
+ assert(q.lastProgress != null, "Failed to make progress")
+ assert(q.lastProgress.numInputRows > 0)
+ }
+
+ eventually(timeout(30.seconds)) {
+ // There should be row(s) in temporary view created by foreachBatch.
+ val rows = spark
+ .sql(s"select * from $tableName")
+ .collect()
+ .toSeq
+ assert(rows.size > 0)
+ assert(rows.map(_.getLong(1)).sum > 0)
+ logInfo(s"Rows in $tableName: $rows")
+ }
+
+ q.stop()
+ }
+ }
+
+ test("foreachBatch with Dataset[TestClass]") {
+ val session: SparkSession = spark
+ import session.implicits._
+ val viewName = "test_view"
+ val tableName = s"global_temp.$viewName"
+
+ val df = spark.readStream
+ .format("rate")
+ .option("rowsPerSecond", "10")
+ .load()
+
+ val q = df
+ .selectExpr("CAST(value AS INT)")
+ .as[TestClass]
+ .writeStream
+ .foreachBatch((ds: Dataset[TestClass], batchId: Long) => {
+ val count = ds.collect().map(_.value).sum
+ })
+ .start()
+ eventually(timeout(30.seconds)) {
+ assert(q.isActive)
+ assert(q.exception.isEmpty)
+ }
+ q.stop()
+ }
+
abstract class EventCollector extends StreamingQueryListener {
protected def tablePostfix: String
@@ -700,14 +774,3 @@ class TestForeachWriter[T] extends ForeachWriter[T] {
case class TestClass(value: Int) {
override def toString: String = value.toString
}
-
-class ForeachBatchFn(val viewName: String)
- extends VoidFunction2[DataFrame, java.lang.Long]
- with Serializable {
- override def call(df: DataFrame, batchId: java.lang.Long): Unit = {
- val count = df.count()
- df.sparkSession
- .createDataFrame(Seq((batchId.toLong, count)))
- .createOrReplaceGlobalTempView(viewName)
- }
-}
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 5ace916ba3e9..d6ade1ac9126 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -2957,10 +2957,9 @@ class SparkConnectPlanner(
fn
case StreamingForeachFunction.FunctionCase.SCALA_FUNCTION =>
- val scalaFn = Utils.deserialize[StreamingForeachBatchHelper.ForeachBatchFnType](
+ StreamingForeachBatchHelper.scalaForeachBatchWrapper(
writeOp.getForeachBatch.getScalaFunction.getPayload.toByteArray,
- Utils.getContextOrSparkClassLoader)
- StreamingForeachBatchHelper.scalaForeachBatchWrapper(scalaFn, sessionHolder)
+ sessionHolder)
case StreamingForeachFunction.FunctionCase.FUNCTION_NOT_SET =>
throw InvalidPlanInput("Unexpected foreachBatch function") // Unreachable
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
index b6f67fe9f02f..ab6bed7152c0 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
@@ -28,11 +28,14 @@ import org.apache.spark.SparkException
import org.apache.spark.api.python.{PythonException, PythonWorkerUtils, SimplePythonFunction, SpecialLengths, StreamingPythonRunner}
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, QUERY_ID, RUN_ID_STRING, SESSION_ID}
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders}
+import org.apache.spark.sql.connect.common.ForeachWriterPacket
import org.apache.spark.sql.connect.service.SessionHolder
import org.apache.spark.sql.connect.service.SparkConnectService
import org.apache.spark.sql.streaming.StreamingQuery
import org.apache.spark.sql.streaming.StreamingQueryListener
+import org.apache.spark.util.Utils
/**
* A helper class for handling ForeachBatch related functionality in Spark Connect servers
@@ -88,13 +91,31 @@ object StreamingForeachBatchHelper extends Logging {
* DataFrame, so the user code actually runs with legacy DataFrame and session..
*/
def scalaForeachBatchWrapper(
- fn: ForeachBatchFnType,
+ payloadBytes: Array[Byte],
sessionHolder: SessionHolder): ForeachBatchFnType = {
+ val foreachBatchPkt =
+ Utils.deserialize[ForeachWriterPacket](payloadBytes, Utils.getContextOrSparkClassLoader)
+ val fn = foreachBatchPkt.foreachWriter.asInstanceOf[(Dataset[Any], Long) => Unit]
+ val encoder = foreachBatchPkt.datasetEncoder.asInstanceOf[AgnosticEncoder[Any]]
// TODO(SPARK-44462): Set up Spark Connect session.
// Do we actually need this for the first version?
dataFrameCachingWrapper(
(args: FnArgsWithId) => {
- fn(args.df, args.batchId) // dfId is not used, see hack comment above.
+ // dfId is not used, see hack comment above.
+ try {
+ val ds = if (AgnosticEncoders.UnboundRowEncoder == encoder) {
+ // When the dataset is a DataFrame (Dataset[Row]).
+ args.df.asInstanceOf[Dataset[Any]]
+ } else {
+ // Recover the Dataset from the DataFrame using the encoder.
+ Dataset.apply(args.df.sparkSession, args.df.logicalPlan)(encoder)
+ }
+ fn(ds, args.batchId)
+ } catch {
+ case t: Throwable =>
+ logError(s"Calling foreachBatch fn failed", t)
+ throw t
+ }
},
sessionHolder)
}