This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 51b011f64273 [SPARK-50661][CONNECT][SS] Fix Spark Connect Scala 
foreachBatch impl. to support Dataset[T].
51b011f64273 is described below

commit 51b011f642730ebcffd6f63d95a632f939e9d474
Author: Haiyang Sun <[email protected]>
AuthorDate: Sat Dec 28 16:20:59 2024 +0900

    [SPARK-50661][CONNECT][SS] Fix Spark Connect Scala foreachBatch impl. to 
support Dataset[T].
    
    ### What changes were proposed in this pull request?
    This PR fixes the incorrect implementation of Scala Streaming foreachBatch when 
the input dataset is not a DataFrame (but a Dataset[T]) in Spark Connect mode.
    
    **Note** that this only affects `Scala`.
    
    In `DataStreamWriter`:
    - Serialize the foreachBatch function together with the dataset's encoder.
    - Reuse `ForeachWriterPacket` for foreachBatch, since both are sink operations 
and only require a function/writer object and the encoder of the input. 
Optionally, we could rename `ForeachWriterPacket` to something more general for 
both cases.
    
    In `SparkConnectPlanner` / `StreamingForeachBatchHelper`:
    - Use the encoder passed from the client to recover the Dataset[T] object 
so that the foreachBatch function is called with the correctly typed dataset.
    
    ### Why are the changes needed?
    Without the fix, Scala foreachBatch will fail or give wrong results when 
the input dataset is not a DataFrame.
    
    Below is a simple reproduction:
    
    ```
    import org.apache.spark.sql._
    spark.range(10).write.format("parquet").mode("overwrite").save("/tmp/test")
    
    val q = spark.readStream.format("parquet").schema("id 
LONG").load("/tmp/test").as[java.lang.Long].writeStream.foreachBatch((ds: 
Dataset[java.lang.Long], batchId: Long) => 
println(ds.collect().map(_.asInstanceOf[Long]).sum)).start()
    
    Thread.sleep(1000)
    q.stop()
    ```
    
    The code above should output 45 in the foreachBatch function. Without the 
fix, the code will fail because the foreachBatch function will be called with a 
DataFrame object instead of a Dataset[java.lang.Long].
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, this PR includes fixes to the Spark Connect client (we add the encoder 
to the foreachBatch function during serialization) around the foreachBatch API.
    
    ### How was this patch tested?
    1. Ran an end-to-end test with spark-shell (with the Spark Connect server and 
client running in connect mode).
    2. Added/updated unit tests that would have failed without the fix.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #49323 from haiyangsun-db/SPARK-50661.
    
    Authored-by: Haiyang Sun <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../spark/sql/streaming/DataStreamWriter.scala     |  5 +-
 .../sql/streaming/ClientStreamingQuerySuite.scala  | 93 ++++++++++++++++++----
 .../sql/connect/planner/SparkConnectPlanner.scala  |  5 +-
 .../planner/StreamingForeachBatchHelper.scala      | 27 ++++++-
 4 files changed, 108 insertions(+), 22 deletions(-)

diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index 9fcc31e56268..b2c4fcf64e70 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -135,7 +135,10 @@ final class DataStreamWriter[T] private[sql] (ds: 
Dataset[T]) extends api.DataSt
   /** @inheritdoc */
   @Evolving
   def foreachBatch(function: (Dataset[T], Long) => Unit): this.type = {
-    val serializedFn = SparkSerDeUtils.serialize(function)
+    // SPARK-50661: the client should send the encoder for the input dataset 
together with the
+    //  function to the server.
+    val serializedFn =
+      SparkSerDeUtils.serialize(ForeachWriterPacket(function, 
ds.agnosticEncoder))
     sinkBuilder.getForeachBatchBuilder.getScalaFunctionBuilder
       .setPayload(ByteString.copyFrom(serializedFn))
       .setOutputType(DataTypeProtoConverter.toConnectProtoType(NullType)) // 
Unused.
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
index b1a7d81916e9..199a1507a3b1 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
@@ -28,9 +28,8 @@ import org.scalatest.concurrent.Futures.timeout
 import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.SparkException
-import org.apache.spark.api.java.function.VoidFunction2
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{DataFrame, ForeachWriter, Row, SparkSession}
+import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row, 
SparkSession}
 import org.apache.spark.sql.functions.{col, lit, udf, window}
 import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryIdleEvent, 
QueryProgressEvent, QueryStartedEvent, QueryTerminatedEvent}
 import org.apache.spark.sql.test.{IntegrationTestUtils, QueryTest, 
RemoteSparkSession}
@@ -567,7 +566,7 @@ class ClientStreamingQuerySuite extends QueryTest with 
RemoteSparkSession with L
     }
   }
 
-  test("foreachBatch") {
+  test("foreachBatch with DataFrame") {
     // Starts a streaming query with a foreachBatch function, which writes 
batchId and row count
     // to a temp view. The test verifies that the view is populated with data.
 
@@ -581,7 +580,12 @@ class ClientStreamingQuerySuite extends QueryTest with 
RemoteSparkSession with L
         .option("numPartitions", "1")
         .load()
         .writeStream
-        .foreachBatch(new ForeachBatchFn(viewName))
+        .foreachBatch((df: DataFrame, batchId: Long) => {
+          val count = df.collect().map(row => row.getLong(1)).sum
+          df.sparkSession
+            .createDataFrame(Seq((batchId, count)))
+            .createOrReplaceGlobalTempView(viewName)
+        })
         .start()
 
       eventually(timeout(30.seconds)) { // Wait for first progress.
@@ -596,6 +600,7 @@ class ClientStreamingQuerySuite extends QueryTest with 
RemoteSparkSession with L
           .collect()
           .toSeq
         assert(rows.size > 0)
+        assert(rows.map(_.getLong(1)).sum > 0)
         logInfo(s"Rows in $tableName: $rows")
       }
 
@@ -603,6 +608,75 @@ class ClientStreamingQuerySuite extends QueryTest with 
RemoteSparkSession with L
     }
   }
 
+  test("foreachBatch with Dataset[java.lang.Long]") {
+    val viewName = "test_view"
+    val tableName = s"global_temp.$viewName"
+
+    withTable(tableName) {
+      val session = spark
+      import session.implicits._
+      val q = spark.readStream
+        .format("rate")
+        .option("rowsPerSecond", "10")
+        .option("numPartitions", "1")
+        .load()
+        .select($"value")
+        .as[java.lang.Long]
+        .writeStream
+        .foreachBatch((ds: Dataset[java.lang.Long], batchId: Long) => {
+          val count = ds.collect().map(v => v.asInstanceOf[Long]).sum
+          ds.sparkSession
+            .createDataFrame(Seq((batchId, count)))
+            .createOrReplaceGlobalTempView(viewName)
+        })
+        .start()
+
+      eventually(timeout(30.seconds)) { // Wait for first progress.
+        assert(q.lastProgress != null, "Failed to make progress")
+        assert(q.lastProgress.numInputRows > 0)
+      }
+
+      eventually(timeout(30.seconds)) {
+        // There should be row(s) in temporary view created by foreachBatch.
+        val rows = spark
+          .sql(s"select * from $tableName")
+          .collect()
+          .toSeq
+        assert(rows.size > 0)
+        assert(rows.map(_.getLong(1)).sum > 0)
+        logInfo(s"Rows in $tableName: $rows")
+      }
+
+      q.stop()
+    }
+  }
+
+  test("foreachBatch with Dataset[TestClass]") {
+    val session: SparkSession = spark
+    import session.implicits._
+    val viewName = "test_view"
+    val tableName = s"global_temp.$viewName"
+
+    val df = spark.readStream
+      .format("rate")
+      .option("rowsPerSecond", "10")
+      .load()
+
+    val q = df
+      .selectExpr("CAST(value AS INT)")
+      .as[TestClass]
+      .writeStream
+      .foreachBatch((ds: Dataset[TestClass], batchId: Long) => {
+        val count = ds.collect().map(_.value).sum
+      })
+      .start()
+    eventually(timeout(30.seconds)) {
+      assert(q.isActive)
+      assert(q.exception.isEmpty)
+    }
+    q.stop()
+  }
+
   abstract class EventCollector extends StreamingQueryListener {
     protected def tablePostfix: String
 
@@ -700,14 +774,3 @@ class TestForeachWriter[T] extends ForeachWriter[T] {
 case class TestClass(value: Int) {
   override def toString: String = value.toString
 }
-
-class ForeachBatchFn(val viewName: String)
-    extends VoidFunction2[DataFrame, java.lang.Long]
-    with Serializable {
-  override def call(df: DataFrame, batchId: java.lang.Long): Unit = {
-    val count = df.count()
-    df.sparkSession
-      .createDataFrame(Seq((batchId.toLong, count)))
-      .createOrReplaceGlobalTempView(viewName)
-  }
-}
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 5ace916ba3e9..d6ade1ac9126 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -2957,10 +2957,9 @@ class SparkConnectPlanner(
           fn
 
         case StreamingForeachFunction.FunctionCase.SCALA_FUNCTION =>
-          val scalaFn = 
Utils.deserialize[StreamingForeachBatchHelper.ForeachBatchFnType](
+          StreamingForeachBatchHelper.scalaForeachBatchWrapper(
             writeOp.getForeachBatch.getScalaFunction.getPayload.toByteArray,
-            Utils.getContextOrSparkClassLoader)
-          StreamingForeachBatchHelper.scalaForeachBatchWrapper(scalaFn, 
sessionHolder)
+            sessionHolder)
 
         case StreamingForeachFunction.FunctionCase.FUNCTION_NOT_SET =>
           throw InvalidPlanInput("Unexpected foreachBatch function") // 
Unreachable
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
index b6f67fe9f02f..ab6bed7152c0 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
@@ -28,11 +28,14 @@ import org.apache.spark.SparkException
 import org.apache.spark.api.python.{PythonException, PythonWorkerUtils, 
SimplePythonFunction, SpecialLengths, StreamingPythonRunner}
 import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, QUERY_ID, 
RUN_ID_STRING, SESSION_ID}
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, 
AgnosticEncoders}
+import org.apache.spark.sql.connect.common.ForeachWriterPacket
 import org.apache.spark.sql.connect.service.SessionHolder
 import org.apache.spark.sql.connect.service.SparkConnectService
 import org.apache.spark.sql.streaming.StreamingQuery
 import org.apache.spark.sql.streaming.StreamingQueryListener
+import org.apache.spark.util.Utils
 
 /**
  * A helper class for handling ForeachBatch related functionality in Spark 
Connect servers
@@ -88,13 +91,31 @@ object StreamingForeachBatchHelper extends Logging {
    * DataFrame, so the user code actually runs with legacy DataFrame and 
session..
    */
   def scalaForeachBatchWrapper(
-      fn: ForeachBatchFnType,
+      payloadBytes: Array[Byte],
       sessionHolder: SessionHolder): ForeachBatchFnType = {
+    val foreachBatchPkt =
+      Utils.deserialize[ForeachWriterPacket](payloadBytes, 
Utils.getContextOrSparkClassLoader)
+    val fn = foreachBatchPkt.foreachWriter.asInstanceOf[(Dataset[Any], Long) 
=> Unit]
+    val encoder = 
foreachBatchPkt.datasetEncoder.asInstanceOf[AgnosticEncoder[Any]]
     // TODO(SPARK-44462): Set up Spark Connect session.
     // Do we actually need this for the first version?
     dataFrameCachingWrapper(
       (args: FnArgsWithId) => {
-        fn(args.df, args.batchId) // dfId is not used, see hack comment above.
+        // dfId is not used, see hack comment above.
+        try {
+          val ds = if (AgnosticEncoders.UnboundRowEncoder == encoder) {
+            // When the dataset is a DataFrame (Dataset[Row]).
+            args.df.asInstanceOf[Dataset[Any]]
+          } else {
+            // Recover the Dataset from the DataFrame using the encoder.
+            Dataset.apply(args.df.sparkSession, args.df.logicalPlan)(encoder)
+          }
+          fn(ds, args.batchId)
+        } catch {
+          case t: Throwable =>
+            logError(s"Calling foreachBatch fn failed", t)
+            throw t
+        }
       },
       sessionHolder)
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to