This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch branch-4.1
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.1 by this push:
new 25a11c4eadef Revert "[SPARK-53525][CONNECT][FOLLOWUP] Spark Connect ArrowBatch Result Chunking - Scala Client"
25a11c4eadef is described below
commit 25a11c4eadef3cf83bd0bade11fbdfae24f57c40
Author: Dongjoon Hyun <[email protected]>
AuthorDate: Wed Nov 5 21:39:08 2025 -0800
Revert "[SPARK-53525][CONNECT][FOLLOWUP] Spark Connect ArrowBatch Result Chunking - Scala Client"
This reverts commit bc0f6f7a8a0db3d56b106f23956aa6e6e999d99d.
---
.../spark/sql/connect/ClientE2ETestSuite.scala | 159 +--------------------
.../sql/connect/test/RemoteSparkSession.scala | 26 ++--
.../sql/connect/client/SparkConnectClient.scala | 45 +-----
.../spark/sql/connect/client/SparkResult.scala | 138 ++++++------------
4 files changed, 57 insertions(+), 311 deletions(-)
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
index 450ff8ca6249..8c336b6fa6d5 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
@@ -26,14 +26,12 @@ import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.{DurationInt, FiniteDuration}
import scala.jdk.CollectionConverters._
-import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, ForwardingClientCall, ForwardingClientCallListener, MethodDescriptor}
import org.apache.commons.io.output.TeeOutputStream
import org.scalactic.TolerantNumerics
import org.scalatest.PrivateMethodTester
import org.apache.spark.{SparkArithmeticException, SparkException, SparkUpgradeException}
import org.apache.spark.SparkBuildInfo.{spark_version => SPARK_VERSION}
-import org.apache.spark.connect.proto
import org.apache.spark.internal.config.ConfigBuilder
import org.apache.spark.sql.{functions, AnalysisException, Observation, Row, SaveMode}
import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchNamespaceException, TableAlreadyExistsException, TempTableAlreadyExistsException}
@@ -43,7 +41,7 @@ import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.connect.ConnectConversions._
import org.apache.spark.sql.connect.client.{RetryPolicy, SparkConnectClient, SparkResult}
import org.apache.spark.sql.connect.test.{ConnectFunSuite, IntegrationTestUtils, QueryTest, RemoteSparkSession, SQLHelper}
-import org.apache.spark.sql.connect.test.SparkConnectServerUtils.{createSparkSession, port}
+import org.apache.spark.sql.connect.test.SparkConnectServerUtils.port
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SqlApiConf
import org.apache.spark.sql.types._
@@ -1850,161 +1848,6 @@ class ClientE2ETestSuite
checkAnswer(df, Seq.empty)
}
}
-
- // Helper class to capture Arrow batch chunk information from gRPC responses
- private class ArrowBatchInterceptor extends ClientInterceptor {
- case class BatchInfo(
- batchIndex: Int,
- rowCount: Long,
- startOffset: Long,
- chunks: Seq[ChunkInfo]) {
- def totalChunks: Int = chunks.length
- }
-
- case class ChunkInfo(
- batchIndex: Int,
- chunkIndex: Int,
- numChunksInBatch: Int,
- rowCount: Long,
- startOffset: Long,
- dataSize: Int)
-
- private val batches: mutable.Buffer[BatchInfo] = mutable.Buffer.empty
- private var currentBatchIndex: Int = 0
- private val currentBatchChunks: mutable.Buffer[ChunkInfo] = mutable.Buffer.empty
-
- override def interceptCall[ReqT, RespT](
- method: MethodDescriptor[ReqT, RespT],
- callOptions: CallOptions,
- next: Channel): ClientCall[ReqT, RespT] = {
- new ForwardingClientCall.SimpleForwardingClientCall[ReqT, RespT](
- next.newCall(method, callOptions)) {
- override def start(
- responseListener: ClientCall.Listener[RespT],
- headers: io.grpc.Metadata): Unit = {
- super.start(
- new ForwardingClientCallListener.SimpleForwardingClientCallListener[RespT](
- responseListener) {
- override def onMessage(message: RespT): Unit = {
- message match {
- case response: proto.ExecutePlanResponse if response.hasArrowBatch =>
- val arrowBatch = response.getArrowBatch
- // Track chunk information for every chunk
- currentBatchChunks += ChunkInfo(
- batchIndex = currentBatchIndex,
- chunkIndex = arrowBatch.getChunkIndex.toInt,
- numChunksInBatch = arrowBatch.getNumChunksInBatch.toInt,
- rowCount = arrowBatch.getRowCount,
- startOffset = arrowBatch.getStartOffset,
- dataSize = arrowBatch.getData.size())
- // When we receive the last chunk, create the BatchInfo
- if (currentBatchChunks.length == arrowBatch.getNumChunksInBatch) {
- batches += BatchInfo(
- batchIndex = currentBatchIndex,
- rowCount = arrowBatch.getRowCount,
- startOffset = arrowBatch.getStartOffset,
- chunks = currentBatchChunks.toList)
- currentBatchChunks.clear()
- currentBatchIndex += 1
- }
- case _ => // Not an ExecutePlanResponse with ArrowBatch, ignore
- }
- super.onMessage(message)
- }
- },
- headers)
- }
- }
- }
-
- // Get all batch information
- def getBatchInfos: Seq[BatchInfo] = batches.toSeq
-
- def clear(): Unit = {
- currentBatchIndex = 0
- currentBatchChunks.clear()
- batches.clear()
- }
- }
-
- test("Arrow batch result chunking") {
- // This test validates that the client can correctly reassemble chunked Arrow batches
- // using SequenceInputStream as implemented in SparkResult.processResponses
-
- // Two cases are tested here:
- // (a) client preferred chunk size is set: the server should respect it
- // (b) client preferred chunk size is not set: the server should use its own max chunk size
- Seq((Some(1024), None), (None, Some(1024))).foreach {
- case (preferredChunkSizeOpt, maxChunkSizeOpt) =>
- // Create interceptor to capture chunk information
- val arrowBatchInterceptor = new ArrowBatchInterceptor()
-
- try {
- // Set preferred chunk size if specified and add interceptor
- preferredChunkSizeOpt match {
- case Some(size) =>
- spark = createSparkSession(
- _.preferredArrowChunkSize(Some(size)).interceptor(arrowBatchInterceptor))
- case None =>
- spark = createSparkSession(_.interceptor(arrowBatchInterceptor))
- }
- // Set server max chunk size if specified
- maxChunkSizeOpt.foreach { size =>
- spark.conf.set("spark.connect.session.resultChunking.maxChunkSize", size.toString)
- }
-
- val sqlQuery =
- "select id, CAST(id + 0.5 AS DOUBLE) as double_val from range(0, 2000, 1, 4)"
-
- // Execute the query using withResult to access SparkResult object
- spark.sql(sqlQuery).withResult { result =>
- // Verify the results are correct and complete
- assert(result.length == 2000)
-
- // Get batch information from interceptor
- val batchInfos = arrowBatchInterceptor.getBatchInfos
-
- // Assert there are 4 batches (partitions) in total
- assert(batchInfos.length == 4)
-
- // Validate chunk information for each batch
- val maxChunkSize = preferredChunkSizeOpt.orElse(maxChunkSizeOpt).get
- batchInfos.foreach { batch =>
- // In this example, the max chunk size is set to a small value,
- // so each Arrow batch should be split into multiple chunks
- assert(batch.totalChunks > 5)
- assert(batch.chunks.nonEmpty)
- assert(batch.chunks.length == batch.totalChunks)
- batch.chunks.zipWithIndex.foreach { case (chunk, expectedIndex) =>
- assert(chunk.chunkIndex == expectedIndex)
- assert(chunk.numChunksInBatch == batch.totalChunks)
- assert(chunk.rowCount == batch.rowCount)
- assert(chunk.startOffset == batch.startOffset)
- assert(chunk.dataSize > 0)
- assert(chunk.dataSize <= maxChunkSize)
- }
- }
-
- // Validate data integrity across the range to ensure chunking didn't corrupt anything
- val rows = result.toArray
- var expectedId = 0L
- rows.foreach { row =>
- assert(row.getLong(0) == expectedId)
- val expectedDouble = expectedId + 0.5
- val actualDouble = row.getDouble(1)
- assert(math.abs(actualDouble - expectedDouble) < 0.001)
- expectedId += 1
- }
- }
- } finally {
- // Clean up configurations
- maxChunkSizeOpt.foreach { _ =>
- spark.conf.unset("spark.connect.session.resultChunking.maxChunkSize")
- }
- arrowBatchInterceptor.clear()
- }
- }
- }
}
private[sql] case class ClassData(a: String, b: Int)
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/RemoteSparkSession.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/RemoteSparkSession.scala
index a239775a3a86..efb6c721876c 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/RemoteSparkSession.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/RemoteSparkSession.scala
@@ -187,27 +187,19 @@ object SparkConnectServerUtils {
}
def createSparkSession(): SparkSession = {
- createSparkSession(identity)
- }
-
- def createSparkSession(
- customBuilderFunc: SparkConnectClient.Builder => SparkConnectClient.Builder)
- : SparkSession = {
SparkConnectServerUtils.start()
- var builder = SparkConnectClient
- .builder()
- .userId("test")
- .port(port)
- .retryPolicy(
- RetryPolicy
- .defaultPolicy()
- .copy(maxRetries = Some(10), maxBackoff = Some(FiniteDuration(30, "s"))))
-
- builder = customBuilderFunc(builder)
val spark = SparkSession
.builder()
- .client(builder.build())
+ .client(
+ SparkConnectClient
+ .builder()
+ .userId("test")
+ .port(port)
+ .retryPolicy(RetryPolicy
+ .defaultPolicy()
+ .copy(maxRetries = Some(10), maxBackoff = Some(FiniteDuration(30, "s"))))
+ .build())
.create()
// Execute an RPC which will get retried until the server is up.
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index e5fd16a7c261..fa32eba91eb2 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -138,22 +138,6 @@ private[sql] class SparkConnectClient(
.setSessionId(sessionId)
.setClientType(userAgent)
.addAllTags(tags.get.toSeq.asJava)
-
- // Add request option to allow result chunking.
- if (configuration.allowArrowBatchChunking) {
- val chunkingOptionsBuilder = proto.ResultChunkingOptions
- .newBuilder()
- .setAllowArrowBatchChunking(true)
- configuration.preferredArrowChunkSize.foreach { size =>
- chunkingOptionsBuilder.setPreferredArrowChunkSize(size)
- }
- request.addRequestOptions(
- proto.ExecutePlanRequest.RequestOption
- .newBuilder()
- .setResultChunkingOptions(chunkingOptionsBuilder.build())
- .build())
- }
-
serverSideSessionId.foreach(session =>
request.setClientObservedServerSideSessionId(session))
operationId.foreach { opId =>
require(
@@ -348,16 +332,6 @@ private[sql] class SparkConnectClient(
def copy(): SparkConnectClient = configuration.toSparkConnectClient
- /**
- * Returns whether arrow batch chunking is allowed.
- */
- def allowArrowBatchChunking: Boolean = configuration.allowArrowBatchChunking
-
- /**
- * Returns the preferred arrow chunk size in bytes.
- */
- def preferredArrowChunkSize: Option[Int] = configuration.preferredArrowChunkSize
-
/**
* Add a single artifact to the client session.
*
@@ -783,21 +757,6 @@ object SparkConnectClient {
this
}
- def allowArrowBatchChunking(allow: Boolean): Builder = {
- _configuration = _configuration.copy(allowArrowBatchChunking = allow)
- this
- }
-
- def allowArrowBatchChunking: Boolean = _configuration.allowArrowBatchChunking
-
- def preferredArrowChunkSize(size: Option[Int]): Builder = {
- size.foreach(s => require(s > 0, "preferredArrowChunkSize must be positive"))
- _configuration = _configuration.copy(preferredArrowChunkSize = size)
- this
- }
-
- def preferredArrowChunkSize: Option[Int] = _configuration.preferredArrowChunkSize
-
def build(): SparkConnectClient = _configuration.toSparkConnectClient
}
@@ -842,9 +801,7 @@ object SparkConnectClient {
interceptors: List[ClientInterceptor] = List.empty,
sessionId: Option[String] = None,
grpcMaxMessageSize: Int = ConnectCommon.CONNECT_GRPC_MAX_MESSAGE_SIZE,
- grpcMaxRecursionLimit: Int = ConnectCommon.CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT,
- allowArrowBatchChunking: Boolean = true,
- preferredArrowChunkSize: Option[Int] = None) {
+ grpcMaxRecursionLimit: Int = ConnectCommon.CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT) {
private def isLocal = host.equals("localhost")
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
index 43265e55a0ca..ef55edd10c8a 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
@@ -16,21 +16,18 @@
*/
package org.apache.spark.sql.connect.client
-import java.io.SequenceInputStream
import java.lang.ref.Cleaner
import java.util.Objects
import scala.collection.mutable
import scala.jdk.CollectionConverters._
-import com.google.protobuf.ByteString
import org.apache.arrow.memory.BufferAllocator
import org.apache.arrow.vector.ipc.message.{ArrowMessage, ArrowRecordBatch}
import org.apache.arrow.vector.types.pojo
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.ExecutePlanResponse.ObservedMetrics
-import org.apache.spark.internal.Logging
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ProductEncoder, UnboundRowEncoder}
@@ -45,8 +42,7 @@ private[sql] class SparkResult[T](
allocator: BufferAllocator,
encoder: AgnosticEncoder[T],
timeZoneId: String)
- extends AutoCloseable
- with Logging { self =>
+ extends AutoCloseable { self =>
case class StageInfo(
stageId: Long,
@@ -122,7 +118,6 @@ private[sql] class SparkResult[T](
stopOnFirstNonEmptyResponse: Boolean = false): Boolean = {
var nonEmpty = false
var stop = false
- val arrowBatchChunksToAssemble = mutable.Buffer.empty[ByteString]
while (!stop && responses.hasNext) {
val response = responses.next()
@@ -156,96 +151,55 @@ private[sql] class SparkResult[T](
stop |= stopOnSchema
}
if (response.hasArrowBatch) {
- val arrowBatch = response.getArrowBatch
- logDebug(
- s"Received arrow batch rows=${arrowBatch.getRowCount} " +
- s"Number of chunks in batch=${arrowBatch.getNumChunksInBatch} " +
- s"Chunk index=${arrowBatch.getChunkIndex} " +
- s"size=${arrowBatch.getData.size()}")
-
- if (arrowBatchChunksToAssemble.nonEmpty) {
- // Expect next chunk of the same batch
- if (arrowBatch.getChunkIndex != arrowBatchChunksToAssemble.size) {
- throw new IllegalStateException(
- s"Expected chunk index ${arrowBatchChunksToAssemble.size} of the " +
- s"arrow batch but got ${arrowBatch.getChunkIndex}.")
- }
- } else {
- // Expect next batch
- if (arrowBatch.hasStartOffset) {
- val expectedStartOffset = arrowBatch.getStartOffset
- if (numRecords != expectedStartOffset) {
- throw new IllegalStateException(
- s"Expected arrow batch to start at row offset $numRecords in results, " +
- s"but received arrow batch starting at offset $expectedStartOffset.")
- }
- }
- if (arrowBatch.getChunkIndex != 0) {
- throw new IllegalStateException(
- s"Expected chunk index 0 of the next arrow batch " +
- s"but got ${arrowBatch.getChunkIndex}.")
- }
+ val ipcStreamBytes = response.getArrowBatch.getData
+ val expectedNumRows = response.getArrowBatch.getRowCount
+ val reader = new MessageIterator(ipcStreamBytes.newInput(), allocator)
+ if (arrowSchema == null) {
+ arrowSchema = reader.schema
+ stop |= stopOnArrowSchema
+ } else if (arrowSchema != reader.schema) {
+ throw new IllegalStateException(
+ s"""Schema Mismatch between expected and received schema:
+ |=== Expected Schema ===
+ |$arrowSchema
+ |=== Received Schema ===
+ |${reader.schema}
+ |""".stripMargin)
}
-
- arrowBatchChunksToAssemble += arrowBatch.getData
-
- // Assemble the chunks to an arrow batch to process if
- // (a) chunking is not enabled (numChunksInBatch is not set or is 0,
- // in this case, it is the single chunk in the batch)
- // (b) or the client has received all chunks of the batch.
- if (!arrowBatch.hasNumChunksInBatch ||
- arrowBatch.getNumChunksInBatch == 0 ||
- arrowBatchChunksToAssemble.size == arrowBatch.getNumChunksInBatch) {
-
- val numChunks = arrowBatchChunksToAssemble.size
- val inputStreams =
- arrowBatchChunksToAssemble.map(_.newInput()).iterator.asJavaEnumeration
- val input = new SequenceInputStream(inputStreams)
- arrowBatchChunksToAssemble.clear()
- logDebug(s"Assembling arrow batch from $numChunks chunks.")
-
- val expectedNumRows = arrowBatch.getRowCount
- val reader = new MessageIterator(input, allocator)
- if (arrowSchema == null) {
- arrowSchema = reader.schema
- stop |= stopOnArrowSchema
- } else if (arrowSchema != reader.schema) {
- throw new IllegalStateException(
- s"""Schema Mismatch between expected and received schema:
- |=== Expected Schema ===
- |$arrowSchema
- |=== Received Schema ===
- |${reader.schema}
- |""".stripMargin)
- }
- if (structType == null) {
- // If the schema is not available yet, fallback to the arrow schema.
- structType = ArrowUtils.fromArrowSchema(reader.schema)
- }
-
- var numRecordsInBatch = 0
- val messages = Seq.newBuilder[ArrowMessage]
- while (reader.hasNext) {
- val message = reader.next()
- message match {
- case batch: ArrowRecordBatch =>
- numRecordsInBatch += batch.getLength
- case _ =>
- }
- messages += message
- }
- if (numRecordsInBatch != expectedNumRows) {
+ if (structType == null) {
+ // If the schema is not available yet, fallback to the arrow schema.
+ structType = ArrowUtils.fromArrowSchema(reader.schema)
+ }
+ if (response.getArrowBatch.hasStartOffset) {
+ val expectedStartOffset = response.getArrowBatch.getStartOffset
+ if (numRecords != expectedStartOffset) {
throw new IllegalStateException(
- s"Expected $expectedNumRows rows in arrow batch but got $numRecordsInBatch.")
+ s"Expected arrow batch to start at row offset $numRecords in results, " +
+ s"but received arrow batch starting at offset $expectedStartOffset.")
}
- // Skip the entire result if it is empty.
- if (numRecordsInBatch > 0) {
- numRecords += numRecordsInBatch
- resultMap.put(nextResultIndex, (reader.bytesRead, messages.result()))
- nextResultIndex += 1
- nonEmpty |= true
- stop |= stopOnFirstNonEmptyResponse
+ }
+ var numRecordsInBatch = 0
+ val messages = Seq.newBuilder[ArrowMessage]
+ while (reader.hasNext) {
+ val message = reader.next()
+ message match {
+ case batch: ArrowRecordBatch =>
+ numRecordsInBatch += batch.getLength
+ case _ =>
}
+ messages += message
+ }
+ if (numRecordsInBatch != expectedNumRows) {
+ throw new IllegalStateException(
+ s"Expected $expectedNumRows rows in arrow batch but got $numRecordsInBatch.")
+ }
+ // Skip the entire result if it is empty.
+ if (numRecordsInBatch > 0) {
+ numRecords += numRecordsInBatch
+ resultMap.put(nextResultIndex, (reader.bytesRead, messages.result()))
+ nextResultIndex += 1
+ nonEmpty |= true
+ stop |= stopOnFirstNonEmptyResponse
}
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]