This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c36b7e58d042 [SPARK-54696][CONNECT] Clean-up ArrowBuffers in Connect
c36b7e58d042 is described below

commit c36b7e58d0422a13228252657e4cff26a762a228
Author: Herman van Hövell <[email protected]>
AuthorDate: Tue Dec 16 08:44:00 2025 +0900

    [SPARK-54696][CONNECT] Clean-up ArrowBuffers in Connect
    
    ### What changes were proposed in this pull request?
    This PR fixes a memory leak in Spark Connect LocalRelations.
    
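    Previously, the Arrow allocator and reader behind `ArrowConverters.fromIPCStream`
    were only released by a task-completion listener. `SparkConnectPlanner` passed
    `TaskContext.get()`, which returns null when the plan is built outside of a task.
    In that case the listener was never registered, and the Arrow buffers were never freed.

    This PR:
    - Changes `fromIPCStream` to return a `CloseableIterator[InternalRow]`. The
      iterator closes its reader and allocator when it is exhausted or explicitly closed.
    - Closes that iterator in a `finally` block in `SparkConnectPlanner` once the
      local relation has been built. This covers both inline data and chunked cached
      local relations.
    - Reads all chunks of a chunked cached local relation through a single
      `ConcatenatingArrowStreamReader` instead of concatenating per-chunk iterators.
    - Moves `CloseableIterator` and `ConcatenatingArrowStreamReader` (together with
      `AbstractMessageIterator` and `MessageIterator`) from the Connect modules to
      `org.apache.spark.sql.util` in `sql/api`, so that `sql/core` can use them.
      The reader classes are also narrowed to `private[sql]`.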
    
    ### Why are the changes needed?
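    It fixes a stability issue. Arrow memory allocated while converting LocalRelation
    data was never released, so it accumulated in the long-running Connect server.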
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Existing tests.
    A dedicated Connect planner test is still TBD.
    Longevity tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #53452 from hvanhovell/fix-arrow-local-relations.
    
    Authored-by: Herman van Hövell <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 .../apache/spark/sql/util}/CloseableIterator.scala |  2 +-
 .../sql/util}/ConcatenatingArrowStreamReader.scala |  8 +-
 .../connect/client/arrow/ArrowEncoderSuite.scala   |  2 +-
 .../apache/spark/sql/connect/SparkSession.scala    |  4 +-
 .../sql/connect/StreamingQueryListenerBus.scala    |  2 +-
 .../client/CustomSparkConnectBlockingStub.scala    |  1 +
 .../ExecutePlanResponseReattachableIterator.scala  |  1 +
 .../connect/client/GrpcExceptionConverter.scala    |  1 +
 .../sql/connect/client/GrpcRetryHandler.scala      |  1 +
 .../sql/connect/client/ResponseValidator.scala     |  1 +
 .../sql/connect/client/SparkConnectClient.scala    |  1 +
 .../spark/sql/connect/client/SparkResult.scala     |  4 +-
 .../connect/client/arrow/ArrowDeserializer.scala   |  2 +-
 .../sql/connect/client/arrow/ArrowSerializer.scala |  3 +-
 .../sql/connect/planner/SparkConnectPlanner.scala  | 38 +++------
 .../spark/sql/connect/SparkConnectServerTest.scala |  3 +-
 .../sql/execution/arrow/ArrowConverters.scala      | 96 +++++++++-------------
 .../sql/execution/arrow/ArrowConvertersSuite.scala | 16 ++--
 18 files changed, 83 insertions(+), 103 deletions(-)

diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CloseableIterator.scala
 b/sql/api/src/main/scala/org/apache/spark/sql/util/CloseableIterator.scala
similarity index 97%
rename from 
sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CloseableIterator.scala
rename to 
sql/api/src/main/scala/org/apache/spark/sql/util/CloseableIterator.scala
index 9de585503a50..dc38c75d3ce7 100644
--- 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CloseableIterator.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/util/CloseableIterator.scala
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.connect.client
+package org.apache.spark.sql.util
 
 private[sql] trait CloseableIterator[E] extends Iterator[E] with AutoCloseable 
{ self =>
   def asJava: java.util.Iterator[E] = new java.util.Iterator[E] with 
AutoCloseable {
diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ConcatenatingArrowStreamReader.scala
 
b/sql/api/src/main/scala/org/apache/spark/sql/util/ConcatenatingArrowStreamReader.scala
similarity index 95%
rename from 
sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ConcatenatingArrowStreamReader.scala
rename to 
sql/api/src/main/scala/org/apache/spark/sql/util/ConcatenatingArrowStreamReader.scala
index 90963c831c25..5de53a568a7d 100644
--- 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ConcatenatingArrowStreamReader.scala
+++ 
b/sql/api/src/main/scala/org/apache/spark/sql/util/ConcatenatingArrowStreamReader.scala
@@ -14,7 +14,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.spark.sql.connect.client.arrow
+package org.apache.spark.sql.util
 
 import java.io.{InputStream, IOException}
 import java.nio.channels.Channels
@@ -34,7 +34,7 @@ import org.apache.arrow.vector.types.pojo.Schema
  * closes its messages when it consumes them. In order to prevent that from 
happening in
  * non-destructive mode we clone the messages before passing them to the 
reading logic.
  */
-class ConcatenatingArrowStreamReader(
+private[sql] class ConcatenatingArrowStreamReader(
     allocator: BufferAllocator,
     input: Iterator[AbstractMessageIterator],
     destructive: Boolean)
@@ -128,7 +128,7 @@ class ConcatenatingArrowStreamReader(
   override def closeReadSource(): Unit = ()
 }
 
-trait AbstractMessageIterator extends Iterator[ArrowMessage] {
+private[sql] trait AbstractMessageIterator extends Iterator[ArrowMessage] {
   def schema: Schema
   def bytesRead: Long
 }
@@ -137,7 +137,7 @@ trait AbstractMessageIterator extends 
Iterator[ArrowMessage] {
  * Decode an Arrow IPC stream into individual messages. Please note that this 
iterator MUST have a
  * valid IPC stream as its input, otherwise construction will fail.
  */
-class MessageIterator(input: InputStream, allocator: BufferAllocator)
+private[sql] class MessageIterator(input: InputStream, allocator: 
BufferAllocator)
     extends AbstractMessageIterator {
   private[this] val in = new ReadChannel(Channels.newChannel(input))
   private[this] val reader = new MessageChannelReader(in, allocator)
diff --git 
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
 
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
index d24369ff5fc7..52a503d62601 100644
--- 
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
+++ 
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
@@ -41,10 +41,10 @@ import 
org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_SECOND
 import org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE
 import org.apache.spark.sql.catalyst.util.SparkDateTimeUtils._
 import org.apache.spark.sql.catalyst.util.SparkIntervalUtils._
-import org.apache.spark.sql.connect.client.CloseableIterator
 import org.apache.spark.sql.connect.client.arrow.FooEnum.FooEnum
 import org.apache.spark.sql.connect.test.ConnectFunSuite
 import org.apache.spark.sql.types.{ArrayType, DataType, DayTimeIntervalType, 
Decimal, DecimalType, Geography, Geometry, IntegerType, Metadata, 
SQLUserDefinedType, StringType, StructType, UserDefinedType, 
YearMonthIntervalType}
+import org.apache.spark.sql.util.CloseableIterator
 import org.apache.spark.unsafe.types.VariantVal
 import org.apache.spark.util.{MaybeNull, SparkStringUtils}
 
diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
index daa2cc2001e4..42dd1a2b9979 100644
--- 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
+++ 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
@@ -51,13 +51,13 @@ import 
org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
 import 
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, 
BoxedLongEncoder, UnboundRowEncoder}
 import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toLiteral
 import org.apache.spark.sql.connect.ConnectConversions._
-import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, 
SparkConnectClient, SparkResult}
+import org.apache.spark.sql.connect.client.{ClassFinder, SparkConnectClient, 
SparkResult}
 import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
 import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
 import org.apache.spark.sql.internal.{SessionState, SharedState, SqlApiConf, 
SubqueryExpression}
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.util.ExecutionListenerManager
+import org.apache.spark.sql.util.{CloseableIterator, ExecutionListenerManager}
 import org.apache.spark.util.ArrayImplicits._
 
 /**
diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/StreamingQueryListenerBus.scala
 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/StreamingQueryListenerBus.scala
index 52b0ea24e9e3..754823835146 100644
--- 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/StreamingQueryListenerBus.scala
+++ 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/StreamingQueryListenerBus.scala
@@ -23,9 +23,9 @@ import scala.jdk.CollectionConverters._
 
 import org.apache.spark.connect.proto.{Command, ExecutePlanResponse, Plan, 
StreamingQueryEventType}
 import org.apache.spark.internal.{Logging, LogKeys}
-import org.apache.spark.sql.connect.client.CloseableIterator
 import org.apache.spark.sql.streaming.StreamingQueryListener
 import org.apache.spark.sql.streaming.StreamingQueryListener.{Event, 
QueryIdleEvent, QueryProgressEvent, QueryStartedEvent, QueryTerminatedEvent}
+import org.apache.spark.sql.util.CloseableIterator
 
 class StreamingQueryListenerBus(sparkSession: SparkSession) extends Logging {
   private val listeners = new CopyOnWriteArrayList[StreamingQueryListener]()
diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
index 913f068fcf34..715da0df7349 100644
--- 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
+++ 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
@@ -21,6 +21,7 @@ import scala.jdk.CollectionConverters._
 import io.grpc.ManagedChannel
 
 import org.apache.spark.connect.proto._
+import org.apache.spark.sql.util.CloseableIterator
 
 private[connect] class CustomSparkConnectBlockingStub(
     channel: ManagedChannel,
diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala
 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala
index f3c13c9c2c4d..131a2e77cc43 100644
--- 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala
+++ 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala
@@ -28,6 +28,7 @@ import io.grpc.stub.StreamObserver
 import org.apache.spark.connect.proto
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.connect.client.GrpcRetryHandler.RetryException
+import org.apache.spark.sql.util.WrappedCloseableIterator
 
 /**
  * Retryable iterator of ExecutePlanResponses to an ExecutePlan call.
diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
index d3dae47f4c47..7e0b0949fcf1 100644
--- 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
+++ 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
@@ -35,6 +35,7 @@ import 
org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException,
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.catalyst.trees.Origin
 import org.apache.spark.sql.streaming.StreamingQueryException
+import org.apache.spark.sql.util.{CloseableIterator, WrappedCloseableIterator}
 import org.apache.spark.util.ArrayImplicits._
 
 /**
diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala
 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala
index 3f4558ee97da..d92dc902fedc 100644
--- 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala
+++ 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala
@@ -23,6 +23,7 @@ import io.grpc.stub.StreamObserver
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.LogKeys.{ERROR, NUM_RETRY, POLICY, 
RETRY_WAIT_TIME}
+import org.apache.spark.sql.util.{CloseableIterator, WrappedCloseableIterator}
 
 private[sql] class GrpcRetryHandler(
     private val policies: Seq[RetryPolicy],
diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ResponseValidator.scala
 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ResponseValidator.scala
index 03548120457f..6cf39b8d1879 100644
--- 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ResponseValidator.scala
+++ 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ResponseValidator.scala
@@ -23,6 +23,7 @@ import io.grpc.{Status, StatusRuntimeException}
 import io.grpc.stub.StreamObserver
 
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.util.{CloseableIterator, WrappedCloseableIterator}
 
 // This is common logic to be shared between different stub instances to keep 
the server-side
 // session id and to validate responses as seen by the client.
diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index 1a7d062470e1..5d36fc45f948 100644
--- 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -40,6 +40,7 @@ import org.apache.spark.internal.LogKeys.{ERROR, RATIO, SIZE, 
TIME}
 import org.apache.spark.sql.connect.RuntimeConfig
 import org.apache.spark.sql.connect.common.ProtoUtils
 import org.apache.spark.sql.connect.common.config.ConnectCommon
+import org.apache.spark.sql.util.CloseableIterator
 import org.apache.spark.util.SparkSystemUtils
 
 /**
diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
index 43265e55a0ca..4199801d8505 100644
--- 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
+++ 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
@@ -35,10 +35,10 @@ import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
 import 
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ProductEncoder, 
UnboundRowEncoder}
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
-import org.apache.spark.sql.connect.client.arrow.{AbstractMessageIterator, 
ArrowDeserializingIterator, ConcatenatingArrowStreamReader, MessageIterator}
+import org.apache.spark.sql.connect.client.arrow.ArrowDeserializingIterator
 import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, 
LiteralValueProtoConverter}
 import org.apache.spark.sql.types.{DataType, StructType}
-import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.sql.util.{AbstractMessageIterator, ArrowUtils, 
CloseableIterator, ConcatenatingArrowStreamReader, MessageIterator}
 
 private[sql] class SparkResult[T](
     responses: CloseableIterator[proto.ExecutePlanResponse],
diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
index 8d5811dda8f3..82029025a7f0 100644
--- 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
+++ 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
@@ -37,9 +37,9 @@ import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
-import org.apache.spark.sql.connect.client.CloseableIterator
 import org.apache.spark.sql.errors.{CompilationErrors, ExecutionErrors}
 import org.apache.spark.sql.types.Decimal
+import org.apache.spark.sql.util.{CloseableIterator, 
ConcatenatingArrowStreamReader, MessageIterator}
 import org.apache.spark.unsafe.types.VariantVal
 
 /**
diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
index 73c9a991ab6a..d547c81afe5a 100644
--- 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
+++ 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
@@ -38,10 +38,9 @@ import 
org.apache.spark.sql.catalyst.DefinedByConstructorParams
 import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, Codec}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
 import org.apache.spark.sql.catalyst.util.{SparkDateTimeUtils, 
SparkIntervalUtils}
-import org.apache.spark.sql.connect.client.CloseableIterator
 import org.apache.spark.sql.errors.ExecutionErrors
 import org.apache.spark.sql.types.Decimal
-import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.sql.util.{ArrowUtils, CloseableIterator}
 import org.apache.spark.unsafe.types.VariantVal
 
 /**
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 7e17a935f599..0d95fe31e063 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -29,7 +29,7 @@ import com.google.protobuf.{Any => ProtoAny, ByteString}
 import io.grpc.{Context, Status, StatusRuntimeException}
 import io.grpc.stub.StreamObserver
 
-import org.apache.spark.{SparkClassNotFoundException, SparkEnv, 
SparkException, TaskContext}
+import org.apache.spark.{SparkClassNotFoundException, SparkEnv, SparkException}
 import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
 import org.apache.spark.connect.proto
@@ -1491,9 +1491,12 @@ class SparkConnectPlanner(
     }
 
     if (rel.hasData) {
-      val (rows, structType) =
-        ArrowConverters.fromIPCStream(rel.getData.toByteArray, 
TaskContext.get())
-      buildLocalRelationFromRows(rows, structType, Option(schema))
+      val (rows, structType) = 
ArrowConverters.fromIPCStream(rel.getData.toByteArray)
+      try {
+        buildLocalRelationFromRows(rows, structType, Option(schema))
+      } finally {
+        rows.close()
+      }
     } else {
       if (schema == null) {
         throw InvalidInputErrors.schemaRequiredForLocalRelation()
@@ -1564,28 +1567,13 @@ class SparkConnectPlanner(
     }
 
     // Load and combine all batches
-    var combinedRows: Iterator[InternalRow] = Iterator.empty
-    var structType: StructType = null
-
-    for ((dataHash, batchIndex) <- dataHashes.zipWithIndex) {
-      val dataBytes = readChunkedCachedLocalRelationBlock(dataHash)
-      val (batchRows, batchStructType) =
-        ArrowConverters.fromIPCStream(dataBytes, TaskContext.get())
-
-      // For the first batch, set the schema; for subsequent batches, verify 
compatibility
-      if (batchIndex == 0) {
-        structType = batchStructType
-        combinedRows = batchRows
-
-      } else {
-        if (batchStructType != structType) {
-          throw 
InvalidInputErrors.chunkedCachedLocalRelationChunksWithDifferentSchema()
-        }
-        combinedRows = combinedRows ++ batchRows
-      }
+    val (rows, structType) =
+      
ArrowConverters.fromIPCStream(dataHashes.iterator.map(readChunkedCachedLocalRelationBlock))
+    try {
+      buildLocalRelationFromRows(rows, structType, Option(schema))
+    } finally {
+      rows.close()
     }
-
-    buildLocalRelationFromRows(combinedRows, structType, Option(schema))
   }
 
   private def toStructTypeOrWrap(dt: DataType): StructType = dt match {
diff --git 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
index 7b9052bb9d2c..77ede8e852e8 100644
--- 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
+++ 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.classic
 import org.apache.spark.sql.connect
-import org.apache.spark.sql.connect.client.{CloseableIterator, 
CustomSparkConnectBlockingStub, ExecutePlanResponseReattachableIterator, 
RetryPolicy, SparkConnectClient, SparkConnectStubState}
+import org.apache.spark.sql.connect.client.{CustomSparkConnectBlockingStub, 
ExecutePlanResponseReattachableIterator, RetryPolicy, SparkConnectClient, 
SparkConnectStubState}
 import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
 import org.apache.spark.sql.connect.common.config.ConnectCommon
 import org.apache.spark.sql.connect.config.Connect
@@ -38,6 +38,7 @@ import org.apache.spark.sql.connect.dsl.MockRemoteSession
 import org.apache.spark.sql.connect.dsl.plans._
 import org.apache.spark.sql.connect.service.{ExecuteHolder, SessionKey, 
SparkConnectService}
 import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.util.CloseableIterator
 
 /**
  * Base class and utilities for a test suite that starts and tests the real 
SparkConnectService
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index 683d8de25e0a..799996126e42 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -22,6 +22,7 @@ import java.nio.channels.{Channels, ReadableByteChannel}
 
 import scala.collection.mutable.ArrayBuffer
 import scala.jdk.CollectionConverters._
+import scala.util.control.NonFatal
 
 import org.apache.arrow.compression.{Lz4CompressionCodec, ZstdCompressionCodec}
 import org.apache.arrow.flatbuf.MessageHeader
@@ -42,12 +43,11 @@ import 
org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.classic.{DataFrame, Dataset, SparkSession}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.sql.util.{ArrowUtils, CloseableIterator, 
ConcatenatingArrowStreamReader, MessageIterator}
 import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, 
ColumnVector}
 import org.apache.spark.util.{ByteBufferOutputStream, SizeEstimator, Utils}
 import org.apache.spark.util.ArrayImplicits._
 
-
 /**
  * Writes serialized ArrowRecordBatches to a DataOutputStream in the Arrow 
stream format.
  */
@@ -296,50 +296,33 @@ private[sql] object ArrowConverters extends Logging {
    * @param context Task Context for Spark
    */
   private[sql] class InternalRowIteratorFromIPCStream(
-      input: Array[Byte],
-      context: TaskContext) extends Iterator[InternalRow] {
-
-    // Keep all the resources we have opened in order, should be closed
-    // in reverse order finally.
-    private val resources = new ArrayBuffer[AutoCloseable]()
+      ipcStreams: Iterator[Array[Byte]],
+      context: TaskContext)
+    extends CloseableIterator[InternalRow] {
 
     // Create an allocator used for all Arrow related memory.
     protected val allocator: BufferAllocator = 
ArrowUtils.rootAllocator.newChildAllocator(
       s"to${this.getClass.getSimpleName}",
       0,
       Long.MaxValue)
-    resources.append(allocator)
 
-    private val reader = try {
-      new ArrowStreamReader(new ByteArrayInputStream(input), allocator)
-    } catch {
-      case e: Exception =>
-        closeAll(resources.toSeq.reverse: _*)
-        throw new IllegalArgumentException(
-          s"Failed to create ArrowStreamReader: ${e.getMessage}", e)
-    }
-    resources.append(reader)
-
-    private val root: VectorSchemaRoot = try {
-      reader.getVectorSchemaRoot
-    } catch {
-      case e: Exception =>
-        closeAll(resources.toSeq.reverse: _*)
-        throw new IllegalArgumentException(
-          s"Failed to read schema from IPC stream: ${e.getMessage}", e)
+    private val reader = {
+      val messages = ipcStreams.map { bytes =>
+        new MessageIterator(new ByteArrayInputStream(bytes), allocator)
+      }
+      new ConcatenatingArrowStreamReader(allocator, messages, destructive = 
true)
     }
-    resources.append(root)
 
     val schema: StructType = try {
-      ArrowUtils.fromArrowSchema(root.getSchema)
+      ArrowUtils.fromArrowSchema(reader.getVectorSchemaRoot.getSchema)
     } catch {
-      case e: Exception =>
-        closeAll(resources.toSeq.reverse: _*)
-        throw new IllegalArgumentException(s"Failed to convert Arrow schema: 
${e.getMessage}", e)
+      case NonFatal(e) =>
+        // Since this triggers a read (which involves allocating buffers) we 
have to clean-up.
+        close()
+        throw e
     }
 
-    // TODO: wrap in exception
-    private var rowIterator: Iterator[InternalRow] = 
vectorSchemaRootToIter(root)
+    private var rowIterator: Iterator[InternalRow] = Iterator.empty
 
     // Metrics to track batch processing
     private var _batchesLoaded: Int = 0
@@ -347,7 +330,7 @@ private[sql] object ArrowConverters extends Logging {
 
     if (context != null) {
       context.addTaskCompletionListener[Unit] { _ =>
-        closeAll(resources.toSeq.reverse: _*)
+        close()
       }
     }
 
@@ -355,28 +338,17 @@ private[sql] object ArrowConverters extends Logging {
     def batchesLoaded: Int = _batchesLoaded
     def totalRowsProcessed: Long = _totalRowsProcessed
 
-    // Loads the next batch from the Arrow reader and returns true or
-    // false if the next batch could be loaded.
-    private def loadNextBatch(): Boolean = {
-      if (reader.loadNextBatch()) {
-        rowIterator = vectorSchemaRootToIter(root)
-        _batchesLoaded += 1
-        true
-      } else {
-        false
-      }
-    }
-
     override def hasNext: Boolean = {
-      if (rowIterator.hasNext) {
-        true
-      } else {
-        if (!loadNextBatch()) {
-          false
+      while (!rowIterator.hasNext) {
+        if (reader.loadNextBatch()) {
+          rowIterator = vectorSchemaRootToIter(reader.getVectorSchemaRoot)
+          _batchesLoaded += 1
         } else {
-          hasNext
+          close()
+          return false
         }
       }
+      true
     }
 
     override def next(): InternalRow = {
@@ -386,6 +358,10 @@ private[sql] object ArrowConverters extends Logging {
       _totalRowsProcessed += 1
       rowIterator.next()
     }
+
+    override def close(): Unit = {
+      closeAll(reader, allocator)
+    }
   }
 
   /**
@@ -511,15 +487,21 @@ private[sql] object ArrowConverters extends Logging {
    * one schema and a varying number of record batches. Returns an iterator 
over the
    * created InternalRow.
    */
-  private[sql] def fromIPCStream(input: Array[Byte], context: TaskContext):
-      (Iterator[InternalRow], StructType) = {
-    fromIPCStreamWithIterator(input, context)
+  private[sql] def fromIPCStream(input: Array[Byte]):
+    (CloseableIterator[InternalRow], StructType) = {
+    fromIPCStream(Iterator.single(input))
+  }
+
+  private[sql] def fromIPCStream(inputs: Iterator[Array[Byte]]):
+    (CloseableIterator[InternalRow], StructType) = {
+    val iterator = new InternalRowIteratorFromIPCStream(inputs, null)
+    (iterator, iterator.schema)
   }
 
   // Overloaded method for tests to access the iterator with metrics
   private[sql] def fromIPCStreamWithIterator(input: Array[Byte], context: 
TaskContext):
-      (InternalRowIteratorFromIPCStream, StructType) = {
-    val iterator = new InternalRowIteratorFromIPCStream(input, context)
+    (InternalRowIteratorFromIPCStream, StructType) = {
+    val iterator = new 
InternalRowIteratorFromIPCStream(Iterator.single(input), context)
     (iterator, iterator.schema)
   }
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
index ccf6b63eb5de..f58a5b7ebd6a 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
@@ -1626,7 +1626,8 @@ class ArrowConvertersSuite extends SharedSparkSession {
       writer.end()
     }
 
-    val (outputRowIter, outputSchema) = 
ArrowConverters.fromIPCStream(out.toByteArray, ctx)
+    val (outputRowIter, outputSchema) = ArrowConverters.
+      fromIPCStreamWithIterator(out.toByteArray, ctx)
     assert(outputSchema == schema)
     val res = outputRowIter.zipWithIndex.map { case (row, i) =>
       assert(row.getInt(0) == i)
@@ -1663,7 +1664,8 @@ class ArrowConvertersSuite extends SharedSparkSession {
       writer.end()
     }
 
-    val (outputRowIter, outputSchema) = 
ArrowConverters.fromIPCStream(out.toByteArray, ctx)
+    val (outputRowIter, outputSchema) = ArrowConverters
+      .fromIPCStreamWithIterator(out.toByteArray, ctx)
     assert(outputSchema == schema)
     val outputRows = outputRowIter.zipWithIndex.map { case (row, i) =>
       assert(row.getInt(0) == i)
@@ -1760,7 +1762,7 @@ class ArrowConvertersSuite extends SharedSparkSession {
     val invalidData = Array[Byte](1, 2, 3, 4, 5)
 
     intercept[Exception] {
-      ArrowConverters.fromIPCStream(invalidData, ctx)
+      ArrowConverters.fromIPCStreamWithIterator(invalidData, ctx)
     }
   }
 
@@ -1769,7 +1771,7 @@ class ArrowConvertersSuite extends SharedSparkSession {
     val emptyData = Array.empty[Byte]
 
     intercept[Exception] {
-      ArrowConverters.fromIPCStream(emptyData, ctx)
+      ArrowConverters.fromIPCStreamWithIterator(emptyData, ctx)
     }
   }
 
@@ -1790,7 +1792,8 @@ class ArrowConvertersSuite extends SharedSparkSession {
 
     // Test with null context - should still work but won't have cleanup 
registration
     val proj = UnsafeProjection.create(schema)
-    val (outputRowIter, outputSchema) = 
ArrowConverters.fromIPCStream(out.toByteArray, null)
+    val (outputRowIter, outputSchema) = ArrowConverters.
+      fromIPCStreamWithIterator(out.toByteArray, null)
     assert(outputSchema == schema)
     val outputRows = outputRowIter.map(proj(_).copy()).toList
     assert(outputRows.length == inputRows.length)
@@ -1884,7 +1887,8 @@ class ArrowConvertersSuite extends SharedSparkSession {
       writer.end()
     }
 
-    val (outputRowIter, outputSchema) = 
ArrowConverters.fromIPCStream(out.toByteArray, ctx)
+    val (outputRowIter, outputSchema) = ArrowConverters.
+      fromIPCStreamWithIterator(out.toByteArray, ctx)
     val proj = UnsafeProjection.create(schema)
     assert(outputSchema == schema)
     val outputRows = outputRowIter.map(proj(_).copy()).toList


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to