This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-4.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.1 by this push:
     new 25a11c4eadef Revert "[SPARK-53525][CONNECT][FOLLOWUP] Spark Connect ArrowBatch Result Chunking - Scala Client"
25a11c4eadef is described below

commit 25a11c4eadef3cf83bd0bade11fbdfae24f57c40
Author: Dongjoon Hyun <[email protected]>
AuthorDate: Wed Nov 5 21:39:08 2025 -0800

    Revert "[SPARK-53525][CONNECT][FOLLOWUP] Spark Connect ArrowBatch Result Chunking - Scala Client"
    
    This reverts commit bc0f6f7a8a0db3d56b106f23956aa6e6e999d99d.
---
 .../spark/sql/connect/ClientE2ETestSuite.scala     | 159 +--------------------
 .../sql/connect/test/RemoteSparkSession.scala      |  26 ++--
 .../sql/connect/client/SparkConnectClient.scala    |  45 +-----
 .../spark/sql/connect/client/SparkResult.scala     | 138 ++++++------------
 4 files changed, 57 insertions(+), 311 deletions(-)

diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
index 450ff8ca6249..8c336b6fa6d5 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
@@ -26,14 +26,12 @@ import scala.concurrent.{ExecutionContext, Future}
 import scala.concurrent.duration.{DurationInt, FiniteDuration}
 import scala.jdk.CollectionConverters._
 
-import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, 
ForwardingClientCall, ForwardingClientCallListener, MethodDescriptor}
 import org.apache.commons.io.output.TeeOutputStream
 import org.scalactic.TolerantNumerics
 import org.scalatest.PrivateMethodTester
 
 import org.apache.spark.{SparkArithmeticException, SparkException, 
SparkUpgradeException}
 import org.apache.spark.SparkBuildInfo.{spark_version => SPARK_VERSION}
-import org.apache.spark.connect.proto
 import org.apache.spark.internal.config.ConfigBuilder
 import org.apache.spark.sql.{functions, AnalysisException, Observation, Row, 
SaveMode}
 import 
org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, 
NoSuchNamespaceException, TableAlreadyExistsException, 
TempTableAlreadyExistsException}
@@ -43,7 +41,7 @@ import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.connect.ConnectConversions._
 import org.apache.spark.sql.connect.client.{RetryPolicy, SparkConnectClient, 
SparkResult}
 import org.apache.spark.sql.connect.test.{ConnectFunSuite, 
IntegrationTestUtils, QueryTest, RemoteSparkSession, SQLHelper}
-import 
org.apache.spark.sql.connect.test.SparkConnectServerUtils.{createSparkSession, 
port}
+import org.apache.spark.sql.connect.test.SparkConnectServerUtils.port
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SqlApiConf
 import org.apache.spark.sql.types._
@@ -1850,161 +1848,6 @@ class ClientE2ETestSuite
       checkAnswer(df, Seq.empty)
     }
   }
-
-  // Helper class to capture Arrow batch chunk information from gRPC responses
-  private class ArrowBatchInterceptor extends ClientInterceptor {
-    case class BatchInfo(
-        batchIndex: Int,
-        rowCount: Long,
-        startOffset: Long,
-        chunks: Seq[ChunkInfo]) {
-      def totalChunks: Int = chunks.length
-    }
-
-    case class ChunkInfo(
-        batchIndex: Int,
-        chunkIndex: Int,
-        numChunksInBatch: Int,
-        rowCount: Long,
-        startOffset: Long,
-        dataSize: Int)
-
-    private val batches: mutable.Buffer[BatchInfo] = mutable.Buffer.empty
-    private var currentBatchIndex: Int = 0
-    private val currentBatchChunks: mutable.Buffer[ChunkInfo] = 
mutable.Buffer.empty
-
-    override def interceptCall[ReqT, RespT](
-        method: MethodDescriptor[ReqT, RespT],
-        callOptions: CallOptions,
-        next: Channel): ClientCall[ReqT, RespT] = {
-      new ForwardingClientCall.SimpleForwardingClientCall[ReqT, RespT](
-        next.newCall(method, callOptions)) {
-        override def start(
-            responseListener: ClientCall.Listener[RespT],
-            headers: io.grpc.Metadata): Unit = {
-          super.start(
-            new 
ForwardingClientCallListener.SimpleForwardingClientCallListener[RespT](
-              responseListener) {
-              override def onMessage(message: RespT): Unit = {
-                message match {
-                  case response: proto.ExecutePlanResponse if 
response.hasArrowBatch =>
-                    val arrowBatch = response.getArrowBatch
-                    // Track chunk information for every chunk
-                    currentBatchChunks += ChunkInfo(
-                      batchIndex = currentBatchIndex,
-                      chunkIndex = arrowBatch.getChunkIndex.toInt,
-                      numChunksInBatch = arrowBatch.getNumChunksInBatch.toInt,
-                      rowCount = arrowBatch.getRowCount,
-                      startOffset = arrowBatch.getStartOffset,
-                      dataSize = arrowBatch.getData.size())
-                    // When we receive the last chunk, create the BatchInfo
-                    if (currentBatchChunks.length == 
arrowBatch.getNumChunksInBatch) {
-                      batches += BatchInfo(
-                        batchIndex = currentBatchIndex,
-                        rowCount = arrowBatch.getRowCount,
-                        startOffset = arrowBatch.getStartOffset,
-                        chunks = currentBatchChunks.toList)
-                      currentBatchChunks.clear()
-                      currentBatchIndex += 1
-                    }
-                  case _ => // Not an ExecutePlanResponse with ArrowBatch, 
ignore
-                }
-                super.onMessage(message)
-              }
-            },
-            headers)
-        }
-      }
-    }
-
-    // Get all batch information
-    def getBatchInfos: Seq[BatchInfo] = batches.toSeq
-
-    def clear(): Unit = {
-      currentBatchIndex = 0
-      currentBatchChunks.clear()
-      batches.clear()
-    }
-  }
-
-  test("Arrow batch result chunking") {
-    // This test validates that the client can correctly reassemble chunked 
Arrow batches
-    // using SequenceInputStream as implemented in SparkResult.processResponses
-
-    // Two cases are tested here:
-    // (a) client preferred chunk size is set: the server should respect it
-    // (b) client preferred chunk size is not set: the server should use its 
own max chunk size
-    Seq((Some(1024), None), (None, Some(1024))).foreach {
-      case (preferredChunkSizeOpt, maxChunkSizeOpt) =>
-        // Create interceptor to capture chunk information
-        val arrowBatchInterceptor = new ArrowBatchInterceptor()
-
-        try {
-          // Set preferred chunk size if specified and add interceptor
-          preferredChunkSizeOpt match {
-            case Some(size) =>
-              spark = createSparkSession(
-                
_.preferredArrowChunkSize(Some(size)).interceptor(arrowBatchInterceptor))
-            case None =>
-              spark = createSparkSession(_.interceptor(arrowBatchInterceptor))
-          }
-          // Set server max chunk size if specified
-          maxChunkSizeOpt.foreach { size =>
-            
spark.conf.set("spark.connect.session.resultChunking.maxChunkSize", 
size.toString)
-          }
-
-          val sqlQuery =
-            "select id, CAST(id + 0.5 AS DOUBLE) as double_val from range(0, 
2000, 1, 4)"
-
-          // Execute the query using withResult to access SparkResult object
-          spark.sql(sqlQuery).withResult { result =>
-            // Verify the results are correct and complete
-            assert(result.length == 2000)
-
-            // Get batch information from interceptor
-            val batchInfos = arrowBatchInterceptor.getBatchInfos
-
-            // Assert there are 4 batches (partitions) in total
-            assert(batchInfos.length == 4)
-
-            // Validate chunk information for each batch
-            val maxChunkSize = 
preferredChunkSizeOpt.orElse(maxChunkSizeOpt).get
-            batchInfos.foreach { batch =>
-              // In this example, the max chunk size is set to a small value,
-              // so each Arrow batch should be split into multiple chunks
-              assert(batch.totalChunks > 5)
-              assert(batch.chunks.nonEmpty)
-              assert(batch.chunks.length == batch.totalChunks)
-              batch.chunks.zipWithIndex.foreach { case (chunk, expectedIndex) 
=>
-                assert(chunk.chunkIndex == expectedIndex)
-                assert(chunk.numChunksInBatch == batch.totalChunks)
-                assert(chunk.rowCount == batch.rowCount)
-                assert(chunk.startOffset == batch.startOffset)
-                assert(chunk.dataSize > 0)
-                assert(chunk.dataSize <= maxChunkSize)
-              }
-            }
-
-            // Validate data integrity across the range to ensure chunking 
didn't corrupt anything
-            val rows = result.toArray
-            var expectedId = 0L
-            rows.foreach { row =>
-              assert(row.getLong(0) == expectedId)
-              val expectedDouble = expectedId + 0.5
-              val actualDouble = row.getDouble(1)
-              assert(math.abs(actualDouble - expectedDouble) < 0.001)
-              expectedId += 1
-            }
-          }
-        } finally {
-          // Clean up configurations
-          maxChunkSizeOpt.foreach { _ =>
-            
spark.conf.unset("spark.connect.session.resultChunking.maxChunkSize")
-          }
-          arrowBatchInterceptor.clear()
-        }
-    }
-  }
 }
 
 private[sql] case class ClassData(a: String, b: Int)
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/RemoteSparkSession.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/RemoteSparkSession.scala
index a239775a3a86..efb6c721876c 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/RemoteSparkSession.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/RemoteSparkSession.scala
@@ -187,27 +187,19 @@ object SparkConnectServerUtils {
   }
 
   def createSparkSession(): SparkSession = {
-    createSparkSession(identity)
-  }
-
-  def createSparkSession(
-      customBuilderFunc: SparkConnectClient.Builder => 
SparkConnectClient.Builder)
-      : SparkSession = {
     SparkConnectServerUtils.start()
 
-    var builder = SparkConnectClient
-      .builder()
-      .userId("test")
-      .port(port)
-      .retryPolicy(
-        RetryPolicy
-          .defaultPolicy()
-          .copy(maxRetries = Some(10), maxBackoff = Some(FiniteDuration(30, 
"s"))))
-
-    builder = customBuilderFunc(builder)
     val spark = SparkSession
       .builder()
-      .client(builder.build())
+      .client(
+        SparkConnectClient
+          .builder()
+          .userId("test")
+          .port(port)
+          .retryPolicy(RetryPolicy
+            .defaultPolicy()
+            .copy(maxRetries = Some(10), maxBackoff = Some(FiniteDuration(30, 
"s"))))
+          .build())
       .create()
 
     // Execute an RPC which will get retried until the server is up.
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index e5fd16a7c261..fa32eba91eb2 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -138,22 +138,6 @@ private[sql] class SparkConnectClient(
       .setSessionId(sessionId)
       .setClientType(userAgent)
       .addAllTags(tags.get.toSeq.asJava)
-
-    // Add request option to allow result chunking.
-    if (configuration.allowArrowBatchChunking) {
-      val chunkingOptionsBuilder = proto.ResultChunkingOptions
-        .newBuilder()
-        .setAllowArrowBatchChunking(true)
-      configuration.preferredArrowChunkSize.foreach { size =>
-        chunkingOptionsBuilder.setPreferredArrowChunkSize(size)
-      }
-      request.addRequestOptions(
-        proto.ExecutePlanRequest.RequestOption
-          .newBuilder()
-          .setResultChunkingOptions(chunkingOptionsBuilder.build())
-          .build())
-    }
-
     serverSideSessionId.foreach(session => 
request.setClientObservedServerSideSessionId(session))
     operationId.foreach { opId =>
       require(
@@ -348,16 +332,6 @@ private[sql] class SparkConnectClient(
 
   def copy(): SparkConnectClient = configuration.toSparkConnectClient
 
-  /**
-   * Returns whether arrow batch chunking is allowed.
-   */
-  def allowArrowBatchChunking: Boolean = configuration.allowArrowBatchChunking
-
-  /**
-   * Returns the preferred arrow chunk size in bytes.
-   */
-  def preferredArrowChunkSize: Option[Int] = 
configuration.preferredArrowChunkSize
-
   /**
    * Add a single artifact to the client session.
    *
@@ -783,21 +757,6 @@ object SparkConnectClient {
       this
     }
 
-    def allowArrowBatchChunking(allow: Boolean): Builder = {
-      _configuration = _configuration.copy(allowArrowBatchChunking = allow)
-      this
-    }
-
-    def allowArrowBatchChunking: Boolean = 
_configuration.allowArrowBatchChunking
-
-    def preferredArrowChunkSize(size: Option[Int]): Builder = {
-      size.foreach(s => require(s > 0, "preferredArrowChunkSize must be 
positive"))
-      _configuration = _configuration.copy(preferredArrowChunkSize = size)
-      this
-    }
-
-    def preferredArrowChunkSize: Option[Int] = 
_configuration.preferredArrowChunkSize
-
     def build(): SparkConnectClient = _configuration.toSparkConnectClient
   }
 
@@ -842,9 +801,7 @@ object SparkConnectClient {
       interceptors: List[ClientInterceptor] = List.empty,
       sessionId: Option[String] = None,
       grpcMaxMessageSize: Int = ConnectCommon.CONNECT_GRPC_MAX_MESSAGE_SIZE,
-      grpcMaxRecursionLimit: Int = 
ConnectCommon.CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT,
-      allowArrowBatchChunking: Boolean = true,
-      preferredArrowChunkSize: Option[Int] = None) {
+      grpcMaxRecursionLimit: Int = 
ConnectCommon.CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT) {
 
     private def isLocal = host.equals("localhost")
 
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
index 43265e55a0ca..ef55edd10c8a 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
@@ -16,21 +16,18 @@
  */
 package org.apache.spark.sql.connect.client
 
-import java.io.SequenceInputStream
 import java.lang.ref.Cleaner
 import java.util.Objects
 
 import scala.collection.mutable
 import scala.jdk.CollectionConverters._
 
-import com.google.protobuf.ByteString
 import org.apache.arrow.memory.BufferAllocator
 import org.apache.arrow.vector.ipc.message.{ArrowMessage, ArrowRecordBatch}
 import org.apache.arrow.vector.types.pojo
 
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.ExecutePlanResponse.ObservedMetrics
-import org.apache.spark.internal.Logging
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
 import 
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ProductEncoder, 
UnboundRowEncoder}
@@ -45,8 +42,7 @@ private[sql] class SparkResult[T](
     allocator: BufferAllocator,
     encoder: AgnosticEncoder[T],
     timeZoneId: String)
-    extends AutoCloseable
-    with Logging { self =>
+    extends AutoCloseable { self =>
 
   case class StageInfo(
       stageId: Long,
@@ -122,7 +118,6 @@ private[sql] class SparkResult[T](
       stopOnFirstNonEmptyResponse: Boolean = false): Boolean = {
     var nonEmpty = false
     var stop = false
-    val arrowBatchChunksToAssemble = mutable.Buffer.empty[ByteString]
     while (!stop && responses.hasNext) {
       val response = responses.next()
 
@@ -156,96 +151,55 @@ private[sql] class SparkResult[T](
         stop |= stopOnSchema
       }
       if (response.hasArrowBatch) {
-        val arrowBatch = response.getArrowBatch
-        logDebug(
-          s"Received arrow batch rows=${arrowBatch.getRowCount} " +
-            s"Number of chunks in batch=${arrowBatch.getNumChunksInBatch} " +
-            s"Chunk index=${arrowBatch.getChunkIndex} " +
-            s"size=${arrowBatch.getData.size()}")
-
-        if (arrowBatchChunksToAssemble.nonEmpty) {
-          // Expect next chunk of the same batch
-          if (arrowBatch.getChunkIndex != arrowBatchChunksToAssemble.size) {
-            throw new IllegalStateException(
-              s"Expected chunk index ${arrowBatchChunksToAssemble.size} of the 
" +
-                s"arrow batch but got ${arrowBatch.getChunkIndex}.")
-          }
-        } else {
-          // Expect next batch
-          if (arrowBatch.hasStartOffset) {
-            val expectedStartOffset = arrowBatch.getStartOffset
-            if (numRecords != expectedStartOffset) {
-              throw new IllegalStateException(
-                s"Expected arrow batch to start at row offset $numRecords in 
results, " +
-                  s"but received arrow batch starting at offset 
$expectedStartOffset.")
-            }
-          }
-          if (arrowBatch.getChunkIndex != 0) {
-            throw new IllegalStateException(
-              s"Expected chunk index 0 of the next arrow batch " +
-                s"but got ${arrowBatch.getChunkIndex}.")
-          }
+        val ipcStreamBytes = response.getArrowBatch.getData
+        val expectedNumRows = response.getArrowBatch.getRowCount
+        val reader = new MessageIterator(ipcStreamBytes.newInput(), allocator)
+        if (arrowSchema == null) {
+          arrowSchema = reader.schema
+          stop |= stopOnArrowSchema
+        } else if (arrowSchema != reader.schema) {
+          throw new IllegalStateException(
+            s"""Schema Mismatch between expected and received schema:
+               |=== Expected Schema ===
+               |$arrowSchema
+               |=== Received Schema ===
+               |${reader.schema}
+               |""".stripMargin)
         }
-
-        arrowBatchChunksToAssemble += arrowBatch.getData
-
-        // Assemble the chunks to an arrow batch to process if
-        // (a) chunking is not enabled (numChunksInBatch is not set or is 0,
-        //     in this case, it is the single chunk in the batch)
-        // (b) or the client has received all chunks of the batch.
-        if (!arrowBatch.hasNumChunksInBatch ||
-          arrowBatch.getNumChunksInBatch == 0 ||
-          arrowBatchChunksToAssemble.size == arrowBatch.getNumChunksInBatch) {
-
-          val numChunks = arrowBatchChunksToAssemble.size
-          val inputStreams =
-            
arrowBatchChunksToAssemble.map(_.newInput()).iterator.asJavaEnumeration
-          val input = new SequenceInputStream(inputStreams)
-          arrowBatchChunksToAssemble.clear()
-          logDebug(s"Assembling arrow batch from $numChunks chunks.")
-
-          val expectedNumRows = arrowBatch.getRowCount
-          val reader = new MessageIterator(input, allocator)
-          if (arrowSchema == null) {
-            arrowSchema = reader.schema
-            stop |= stopOnArrowSchema
-          } else if (arrowSchema != reader.schema) {
-            throw new IllegalStateException(
-              s"""Schema Mismatch between expected and received schema:
-                 |=== Expected Schema ===
-                 |$arrowSchema
-                 |=== Received Schema ===
-                 |${reader.schema}
-                 |""".stripMargin)
-          }
-          if (structType == null) {
-            // If the schema is not available yet, fallback to the arrow 
schema.
-            structType = ArrowUtils.fromArrowSchema(reader.schema)
-          }
-
-          var numRecordsInBatch = 0
-          val messages = Seq.newBuilder[ArrowMessage]
-          while (reader.hasNext) {
-            val message = reader.next()
-            message match {
-              case batch: ArrowRecordBatch =>
-                numRecordsInBatch += batch.getLength
-              case _ =>
-            }
-            messages += message
-          }
-          if (numRecordsInBatch != expectedNumRows) {
+        if (structType == null) {
+          // If the schema is not available yet, fallback to the arrow schema.
+          structType = ArrowUtils.fromArrowSchema(reader.schema)
+        }
+        if (response.getArrowBatch.hasStartOffset) {
+          val expectedStartOffset = response.getArrowBatch.getStartOffset
+          if (numRecords != expectedStartOffset) {
             throw new IllegalStateException(
-              s"Expected $expectedNumRows rows in arrow batch but got 
$numRecordsInBatch.")
+              s"Expected arrow batch to start at row offset $numRecords in 
results, " +
+                s"but received arrow batch starting at offset 
$expectedStartOffset.")
           }
-          // Skip the entire result if it is empty.
-          if (numRecordsInBatch > 0) {
-            numRecords += numRecordsInBatch
-            resultMap.put(nextResultIndex, (reader.bytesRead, 
messages.result()))
-            nextResultIndex += 1
-            nonEmpty |= true
-            stop |= stopOnFirstNonEmptyResponse
+        }
+        var numRecordsInBatch = 0
+        val messages = Seq.newBuilder[ArrowMessage]
+        while (reader.hasNext) {
+          val message = reader.next()
+          message match {
+            case batch: ArrowRecordBatch =>
+              numRecordsInBatch += batch.getLength
+            case _ =>
           }
+          messages += message
+        }
+        if (numRecordsInBatch != expectedNumRows) {
+          throw new IllegalStateException(
+            s"Expected $expectedNumRows rows in arrow batch but got 
$numRecordsInBatch.")
+        }
+        // Skip the entire result if it is empty.
+        if (numRecordsInBatch > 0) {
+          numRecords += numRecordsInBatch
+          resultMap.put(nextResultIndex, (reader.bytesRead, messages.result()))
+          nextResultIndex += 1
+          nonEmpty |= true
+          stop |= stopOnFirstNonEmptyResponse
         }
       }
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to