This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch branch-4.1
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.1 by this push:
new 37ec5c4fab2d [SPARK-54696][CONNECT] Clean-up ArrowBuffers in Connect
37ec5c4fab2d is described below
commit 37ec5c4fab2d0f20139ac953008e616d3f3f2858
Author: Herman van Hövell <[email protected]>
AuthorDate: Tue Dec 16 08:44:00 2025 +0900
[SPARK-54696][CONNECT] Clean-up ArrowBuffers in Connect
### What changes were proposed in this pull request?
This PR fixes a memory leak in Spark Connect LocalRelations.
... more details TBD ...
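In short, `ArrowConverters.fromIPCStream` now returns a `CloseableIterator[InternalRow]` (backed by `ConcatenatingArrowStreamReader` and a dedicated child `BufferAllocator`), and `SparkConnectPlanner` closes that iterator once the `LocalRelation` has been built. A minimal sketch of the pattern, with a hypothetical `buildRelationFromIPC` helper standing in for the planner code (the helper itself is not part of this patch):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.types.StructType

// Hypothetical helper illustrating the clean-up pattern applied in
// SparkConnectPlanner: decode the Arrow IPC bytes, build the relation,
// and always close the iterator so its child BufferAllocator is released.
def buildRelationFromIPC(
    ipcBytes: Array[Byte],
    build: (Iterator[InternalRow], StructType) => LogicalPlan): LogicalPlan = {
  val (rows, structType) = ArrowConverters.fromIPCStream(ipcBytes)
  try {
    build(rows, structType)
  } finally {
    rows.close() // releases the Arrow buffers even if `build` throws
  }
}
```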
### Why are the changes needed?
It fixes a stability issue: the Arrow buffers backing Connect LocalRelations were not released after planning, so the memory leak described above degrades the server over long-running workloads.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Existing tests.
A Connect Planner Test TBD
Longevity tests.
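For reference, the chunked cached LocalRelation path now goes through the new iterator-of-chunks overload of `ArrowConverters.fromIPCStream`. A minimal sketch of how that overload is consumed and cleaned up, with a hypothetical `countRows` helper and `chunks` source (not code from this patch):

```scala
import org.apache.spark.sql.execution.arrow.ArrowConverters

// Hypothetical helper: `chunks` yields Arrow IPC streams that all share one
// schema, e.g. the blocks read back for a chunked cached LocalRelation.
def countRows(chunks: Iterator[Array[Byte]]): Long = {
  // The overload concatenates the per-chunk message iterators into a single
  // Arrow stream reader backed by a dedicated child allocator.
  val (rows, _) = ArrowConverters.fromIPCStream(chunks)
  try {
    var n = 0L
    while (rows.hasNext) {
      rows.next()
      n += 1
    }
    n
  } finally {
    rows.close() // releases the reader and its Arrow buffers
  }
}
```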
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #53452 from hvanhovell/fix-arrow-local-relations.
Authored-by: Herman van Hövell <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
(cherry picked from commit c36b7e58d0422a13228252657e4cff26a762a228)
Signed-off-by: Dongjoon Hyun <[email protected]>
---
.../apache/spark/sql/util}/CloseableIterator.scala | 2 +-
.../sql/util}/ConcatenatingArrowStreamReader.scala | 8 +-
.../connect/client/arrow/ArrowEncoderSuite.scala | 2 +-
.../apache/spark/sql/connect/SparkSession.scala | 4 +-
.../sql/connect/StreamingQueryListenerBus.scala | 2 +-
.../client/CustomSparkConnectBlockingStub.scala | 1 +
.../ExecutePlanResponseReattachableIterator.scala | 1 +
.../connect/client/GrpcExceptionConverter.scala | 1 +
.../sql/connect/client/GrpcRetryHandler.scala | 1 +
.../sql/connect/client/ResponseValidator.scala | 1 +
.../sql/connect/client/SparkConnectClient.scala | 1 +
.../spark/sql/connect/client/SparkResult.scala | 4 +-
.../connect/client/arrow/ArrowDeserializer.scala | 2 +-
.../sql/connect/client/arrow/ArrowSerializer.scala | 3 +-
.../sql/connect/planner/SparkConnectPlanner.scala | 38 +++------
.../spark/sql/connect/SparkConnectServerTest.scala | 3 +-
.../sql/execution/arrow/ArrowConverters.scala | 96 +++++++++-------------
.../sql/execution/arrow/ArrowConvertersSuite.scala | 16 ++--
18 files changed, 83 insertions(+), 103 deletions(-)
diff --git
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CloseableIterator.scala
b/sql/api/src/main/scala/org/apache/spark/sql/util/CloseableIterator.scala
similarity index 97%
rename from
sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CloseableIterator.scala
rename to
sql/api/src/main/scala/org/apache/spark/sql/util/CloseableIterator.scala
index 9de585503a50..dc38c75d3ce7 100644
---
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CloseableIterator.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/util/CloseableIterator.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.sql.connect.client
+package org.apache.spark.sql.util
private[sql] trait CloseableIterator[E] extends Iterator[E] with AutoCloseable
{ self =>
def asJava: java.util.Iterator[E] = new java.util.Iterator[E] with
AutoCloseable {
diff --git
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ConcatenatingArrowStreamReader.scala
b/sql/api/src/main/scala/org/apache/spark/sql/util/ConcatenatingArrowStreamReader.scala
similarity index 95%
rename from
sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ConcatenatingArrowStreamReader.scala
rename to
sql/api/src/main/scala/org/apache/spark/sql/util/ConcatenatingArrowStreamReader.scala
index 90963c831c25..5de53a568a7d 100644
---
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ConcatenatingArrowStreamReader.scala
+++
b/sql/api/src/main/scala/org/apache/spark/sql/util/ConcatenatingArrowStreamReader.scala
@@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.spark.sql.connect.client.arrow
+package org.apache.spark.sql.util
import java.io.{InputStream, IOException}
import java.nio.channels.Channels
@@ -34,7 +34,7 @@ import org.apache.arrow.vector.types.pojo.Schema
* closes its messages when it consumes them. In order to prevent that from happening in
* non-destructive mode we clone the messages before passing them to the reading logic.
*/
-class ConcatenatingArrowStreamReader(
+private[sql] class ConcatenatingArrowStreamReader(
allocator: BufferAllocator,
input: Iterator[AbstractMessageIterator],
destructive: Boolean)
@@ -128,7 +128,7 @@ class ConcatenatingArrowStreamReader(
override def closeReadSource(): Unit = ()
}
-trait AbstractMessageIterator extends Iterator[ArrowMessage] {
+private[sql] trait AbstractMessageIterator extends Iterator[ArrowMessage] {
def schema: Schema
def bytesRead: Long
}
@@ -137,7 +137,7 @@ trait AbstractMessageIterator extends Iterator[ArrowMessage] {
* Decode an Arrow IPC stream into individual messages. Please note that this iterator MUST have a
* valid IPC stream as its input, otherwise construction will fail.
*/
-class MessageIterator(input: InputStream, allocator: BufferAllocator)
+private[sql] class MessageIterator(input: InputStream, allocator: BufferAllocator)
extends AbstractMessageIterator {
private[this] val in = new ReadChannel(Channels.newChannel(input))
private[this] val reader = new MessageChannelReader(in, allocator)
diff --git
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
index d24369ff5fc7..52a503d62601 100644
---
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
+++
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
@@ -41,10 +41,10 @@ import
org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_SECOND
import org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE
import org.apache.spark.sql.catalyst.util.SparkDateTimeUtils._
import org.apache.spark.sql.catalyst.util.SparkIntervalUtils._
-import org.apache.spark.sql.connect.client.CloseableIterator
import org.apache.spark.sql.connect.client.arrow.FooEnum.FooEnum
import org.apache.spark.sql.connect.test.ConnectFunSuite
import org.apache.spark.sql.types.{ArrayType, DataType, DayTimeIntervalType,
Decimal, DecimalType, Geography, Geometry, IntegerType, Metadata,
SQLUserDefinedType, StringType, StructType, UserDefinedType,
YearMonthIntervalType}
+import org.apache.spark.sql.util.CloseableIterator
import org.apache.spark.unsafe.types.VariantVal
import org.apache.spark.util.{MaybeNull, SparkStringUtils}
diff --git
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
index daa2cc2001e4..42dd1a2b9979 100644
---
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
+++
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
@@ -51,13 +51,13 @@ import
org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
import
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor,
BoxedLongEncoder, UnboundRowEncoder}
import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toLiteral
import org.apache.spark.sql.connect.ConnectConversions._
-import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator,
SparkConnectClient, SparkResult}
+import org.apache.spark.sql.connect.client.{ClassFinder, SparkConnectClient,
SparkResult}
import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
import org.apache.spark.sql.internal.{SessionState, SharedState, SqlApiConf,
SubqueryExpression}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.util.ExecutionListenerManager
+import org.apache.spark.sql.util.{CloseableIterator, ExecutionListenerManager}
import org.apache.spark.util.ArrayImplicits._
/**
diff --git
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/StreamingQueryListenerBus.scala
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/StreamingQueryListenerBus.scala
index 52b0ea24e9e3..754823835146 100644
---
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/StreamingQueryListenerBus.scala
+++
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/StreamingQueryListenerBus.scala
@@ -23,9 +23,9 @@ import scala.jdk.CollectionConverters._
import org.apache.spark.connect.proto.{Command, ExecutePlanResponse, Plan,
StreamingQueryEventType}
import org.apache.spark.internal.{Logging, LogKeys}
-import org.apache.spark.sql.connect.client.CloseableIterator
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.sql.streaming.StreamingQueryListener.{Event,
QueryIdleEvent, QueryProgressEvent, QueryStartedEvent, QueryTerminatedEvent}
+import org.apache.spark.sql.util.CloseableIterator
class StreamingQueryListenerBus(sparkSession: SparkSession) extends Logging {
private val listeners = new CopyOnWriteArrayList[StreamingQueryListener]()
diff --git
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
index 913f068fcf34..715da0df7349 100644
---
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
+++
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
@@ -21,6 +21,7 @@ import scala.jdk.CollectionConverters._
import io.grpc.ManagedChannel
import org.apache.spark.connect.proto._
+import org.apache.spark.sql.util.CloseableIterator
private[connect] class CustomSparkConnectBlockingStub(
channel: ManagedChannel,
diff --git
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala
index f3c13c9c2c4d..131a2e77cc43 100644
---
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala
+++
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala
@@ -28,6 +28,7 @@ import io.grpc.stub.StreamObserver
import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
import org.apache.spark.sql.connect.client.GrpcRetryHandler.RetryException
+import org.apache.spark.sql.util.WrappedCloseableIterator
/**
* Retryable iterator of ExecutePlanResponses to an ExecutePlan call.
diff --git
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
index d3dae47f4c47..7e0b0949fcf1 100644
---
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
+++
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
@@ -35,6 +35,7 @@ import
org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException,
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.streaming.StreamingQueryException
+import org.apache.spark.sql.util.{CloseableIterator, WrappedCloseableIterator}
import org.apache.spark.util.ArrayImplicits._
/**
diff --git
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala
index 3f4558ee97da..d92dc902fedc 100644
---
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala
+++
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala
@@ -23,6 +23,7 @@ import io.grpc.stub.StreamObserver
import org.apache.spark.internal.Logging
import org.apache.spark.internal.LogKeys.{ERROR, NUM_RETRY, POLICY,
RETRY_WAIT_TIME}
+import org.apache.spark.sql.util.{CloseableIterator, WrappedCloseableIterator}
private[sql] class GrpcRetryHandler(
private val policies: Seq[RetryPolicy],
diff --git
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ResponseValidator.scala
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ResponseValidator.scala
index 03548120457f..6cf39b8d1879 100644
---
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ResponseValidator.scala
+++
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ResponseValidator.scala
@@ -23,6 +23,7 @@ import io.grpc.{Status, StatusRuntimeException}
import io.grpc.stub.StreamObserver
import org.apache.spark.internal.Logging
+import org.apache.spark.sql.util.{CloseableIterator, WrappedCloseableIterator}
// This is common logic to be shared between different stub instances to keep
the server-side
// session id and to validate responses as seen by the client.
diff --git
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index 1a7d062470e1..5d36fc45f948 100644
---
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -40,6 +40,7 @@ import org.apache.spark.internal.LogKeys.{ERROR, RATIO, SIZE,
TIME}
import org.apache.spark.sql.connect.RuntimeConfig
import org.apache.spark.sql.connect.common.ProtoUtils
import org.apache.spark.sql.connect.common.config.ConnectCommon
+import org.apache.spark.sql.util.CloseableIterator
import org.apache.spark.util.SparkSystemUtils
/**
diff --git
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
index 43265e55a0ca..4199801d8505 100644
---
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
+++
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
@@ -35,10 +35,10 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
import
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ProductEncoder,
UnboundRowEncoder}
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
-import org.apache.spark.sql.connect.client.arrow.{AbstractMessageIterator,
ArrowDeserializingIterator, ConcatenatingArrowStreamReader, MessageIterator}
+import org.apache.spark.sql.connect.client.arrow.ArrowDeserializingIterator
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter,
LiteralValueProtoConverter}
import org.apache.spark.sql.types.{DataType, StructType}
-import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.sql.util.{AbstractMessageIterator, ArrowUtils,
CloseableIterator, ConcatenatingArrowStreamReader, MessageIterator}
private[sql] class SparkResult[T](
responses: CloseableIterator[proto.ExecutePlanResponse],
diff --git
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
index 8d5811dda8f3..82029025a7f0 100644
---
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
+++
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
@@ -37,9 +37,9 @@ import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
-import org.apache.spark.sql.connect.client.CloseableIterator
import org.apache.spark.sql.errors.{CompilationErrors, ExecutionErrors}
import org.apache.spark.sql.types.Decimal
+import org.apache.spark.sql.util.{CloseableIterator,
ConcatenatingArrowStreamReader, MessageIterator}
import org.apache.spark.unsafe.types.VariantVal
/**
diff --git
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
index 73c9a991ab6a..d547c81afe5a 100644
---
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
+++
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
@@ -38,10 +38,9 @@ import
org.apache.spark.sql.catalyst.DefinedByConstructorParams
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, Codec}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
import org.apache.spark.sql.catalyst.util.{SparkDateTimeUtils,
SparkIntervalUtils}
-import org.apache.spark.sql.connect.client.CloseableIterator
import org.apache.spark.sql.errors.ExecutionErrors
import org.apache.spark.sql.types.Decimal
-import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.sql.util.{ArrowUtils, CloseableIterator}
import org.apache.spark.unsafe.types.VariantVal
/**
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 9af2e7cb4661..9cbb760f6cc0 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -30,7 +30,7 @@ import com.google.protobuf.{Any => ProtoAny, ByteString,
Message}
import io.grpc.{Context, Status, StatusRuntimeException}
import io.grpc.stub.StreamObserver
-import org.apache.spark.{SparkClassNotFoundException, SparkEnv,
SparkException, TaskContext}
+import org.apache.spark.{SparkClassNotFoundException, SparkEnv, SparkException}
import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
import org.apache.spark.connect.proto
@@ -1492,9 +1492,12 @@ class SparkConnectPlanner(
}
if (rel.hasData) {
- val (rows, structType) =
- ArrowConverters.fromIPCStream(rel.getData.toByteArray, TaskContext.get())
- buildLocalRelationFromRows(rows, structType, Option(schema))
+ val (rows, structType) = ArrowConverters.fromIPCStream(rel.getData.toByteArray)
+ try {
+ buildLocalRelationFromRows(rows, structType, Option(schema))
+ } finally {
+ rows.close()
+ }
} else {
if (schema == null) {
throw InvalidInputErrors.schemaRequiredForLocalRelation()
@@ -1565,28 +1568,13 @@ class SparkConnectPlanner(
}
// Load and combine all batches
- var combinedRows: Iterator[InternalRow] = Iterator.empty
- var structType: StructType = null
-
- for ((dataHash, batchIndex) <- dataHashes.zipWithIndex) {
- val dataBytes = readChunkedCachedLocalRelationBlock(dataHash)
- val (batchRows, batchStructType) =
- ArrowConverters.fromIPCStream(dataBytes, TaskContext.get())
-
- // For the first batch, set the schema; for subsequent batches, verify compatibility
- if (batchIndex == 0) {
- structType = batchStructType
- combinedRows = batchRows
-
- } else {
- if (batchStructType != structType) {
- throw InvalidInputErrors.chunkedCachedLocalRelationChunksWithDifferentSchema()
- }
- combinedRows = combinedRows ++ batchRows
- }
+ val (rows, structType) =
+ ArrowConverters.fromIPCStream(dataHashes.iterator.map(readChunkedCachedLocalRelationBlock))
+ try {
+ buildLocalRelationFromRows(rows, structType, Option(schema))
+ } finally {
+ rows.close()
}
-
- buildLocalRelationFromRows(combinedRows, structType, Option(schema))
}
private def toStructTypeOrWrap(dt: DataType): StructType = dt match {
diff --git
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
index 7b9052bb9d2c..77ede8e852e8 100644
---
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
+++
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.classic
import org.apache.spark.sql.connect
-import org.apache.spark.sql.connect.client.{CloseableIterator,
CustomSparkConnectBlockingStub, ExecutePlanResponseReattachableIterator,
RetryPolicy, SparkConnectClient, SparkConnectStubState}
+import org.apache.spark.sql.connect.client.{CustomSparkConnectBlockingStub,
ExecutePlanResponseReattachableIterator, RetryPolicy, SparkConnectClient,
SparkConnectStubState}
import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
import org.apache.spark.sql.connect.common.config.ConnectCommon
import org.apache.spark.sql.connect.config.Connect
@@ -38,6 +38,7 @@ import org.apache.spark.sql.connect.dsl.MockRemoteSession
import org.apache.spark.sql.connect.dsl.plans._
import org.apache.spark.sql.connect.service.{ExecuteHolder, SessionKey,
SparkConnectService}
import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.util.CloseableIterator
/**
* Base class and utilities for a test suite that starts and tests the real
SparkConnectService
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index 8b031af14e8b..e227a5852872 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -22,6 +22,7 @@ import java.nio.channels.{Channels, ReadableByteChannel}
import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters._
+import scala.util.control.NonFatal
import org.apache.arrow.compression.{Lz4CompressionCodec, ZstdCompressionCodec}
import org.apache.arrow.flatbuf.MessageHeader
@@ -42,12 +43,11 @@ import
org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.classic.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
-import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.sql.util.{ArrowUtils, CloseableIterator,
ConcatenatingArrowStreamReader, MessageIterator}
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch,
ColumnVector}
import org.apache.spark.util.{ByteBufferOutputStream, SizeEstimator, Utils}
import org.apache.spark.util.ArrayImplicits._
-
/**
* Writes serialized ArrowRecordBatches to a DataOutputStream in the Arrow
stream format.
*/
@@ -297,50 +297,33 @@ private[sql] object ArrowConverters extends Logging {
* @param context Task Context for Spark
*/
private[sql] class InternalRowIteratorFromIPCStream(
- input: Array[Byte],
- context: TaskContext) extends Iterator[InternalRow] {
-
- // Keep all the resources we have opened in order, should be closed
- // in reverse order finally.
- private val resources = new ArrayBuffer[AutoCloseable]()
+ ipcStreams: Iterator[Array[Byte]],
+ context: TaskContext)
+ extends CloseableIterator[InternalRow] {
// Create an allocator used for all Arrow related memory.
protected val allocator: BufferAllocator =
ArrowUtils.rootAllocator.newChildAllocator(
s"to${this.getClass.getSimpleName}",
0,
Long.MaxValue)
- resources.append(allocator)
- private val reader = try {
- new ArrowStreamReader(new ByteArrayInputStream(input), allocator)
- } catch {
- case e: Exception =>
- closeAll(resources.toSeq.reverse: _*)
- throw new IllegalArgumentException(
- s"Failed to create ArrowStreamReader: ${e.getMessage}", e)
- }
- resources.append(reader)
-
- private val root: VectorSchemaRoot = try {
- reader.getVectorSchemaRoot
- } catch {
- case e: Exception =>
- closeAll(resources.toSeq.reverse: _*)
- throw new IllegalArgumentException(
- s"Failed to read schema from IPC stream: ${e.getMessage}", e)
+ private val reader = {
+ val messages = ipcStreams.map { bytes =>
+ new MessageIterator(new ByteArrayInputStream(bytes), allocator)
+ }
+ new ConcatenatingArrowStreamReader(allocator, messages, destructive = true)
}
- resources.append(root)
val schema: StructType = try {
- ArrowUtils.fromArrowSchema(root.getSchema)
+ ArrowUtils.fromArrowSchema(reader.getVectorSchemaRoot.getSchema)
} catch {
- case e: Exception =>
- closeAll(resources.toSeq.reverse: _*)
- throw new IllegalArgumentException(s"Failed to convert Arrow schema: ${e.getMessage}", e)
+ case NonFatal(e) =>
+ // Since this triggers a read (which involves allocating buffers) we have to clean-up.
+ close()
+ throw e
}
- // TODO: wrap in exception
- private var rowIterator: Iterator[InternalRow] = vectorSchemaRootToIter(root)
+ private var rowIterator: Iterator[InternalRow] = Iterator.empty
// Metrics to track batch processing
private var _batchesLoaded: Int = 0
@@ -348,7 +331,7 @@ private[sql] object ArrowConverters extends Logging {
if (context != null) {
context.addTaskCompletionListener[Unit] { _ =>
- closeAll(resources.toSeq.reverse: _*)
+ close()
}
}
@@ -356,28 +339,17 @@ private[sql] object ArrowConverters extends Logging {
def batchesLoaded: Int = _batchesLoaded
def totalRowsProcessed: Long = _totalRowsProcessed
- // Loads the next batch from the Arrow reader and returns true or
- // false if the next batch could be loaded.
- private def loadNextBatch(): Boolean = {
- if (reader.loadNextBatch()) {
- rowIterator = vectorSchemaRootToIter(root)
- _batchesLoaded += 1
- true
- } else {
- false
- }
- }
-
override def hasNext: Boolean = {
- if (rowIterator.hasNext) {
- true
- } else {
- if (!loadNextBatch()) {
- false
+ while (!rowIterator.hasNext) {
+ if (reader.loadNextBatch()) {
+ rowIterator = vectorSchemaRootToIter(reader.getVectorSchemaRoot)
+ _batchesLoaded += 1
} else {
- hasNext
+ close()
+ return false
}
}
+ true
}
override def next(): InternalRow = {
@@ -387,6 +359,10 @@ private[sql] object ArrowConverters extends Logging {
_totalRowsProcessed += 1
rowIterator.next()
}
+
+ override def close(): Unit = {
+ closeAll(reader, allocator)
+ }
}
/**
@@ -512,15 +488,21 @@ private[sql] object ArrowConverters extends Logging {
* one schema and a varying number of record batches. Returns an iterator over the
* created InternalRow.
*/
- private[sql] def fromIPCStream(input: Array[Byte], context: TaskContext):
- (Iterator[InternalRow], StructType) = {
- fromIPCStreamWithIterator(input, context)
+ private[sql] def fromIPCStream(input: Array[Byte]):
+ (CloseableIterator[InternalRow], StructType) = {
+ fromIPCStream(Iterator.single(input))
+ }
+
+ private[sql] def fromIPCStream(inputs: Iterator[Array[Byte]]):
+ (CloseableIterator[InternalRow], StructType) = {
+ val iterator = new InternalRowIteratorFromIPCStream(inputs, null)
+ (iterator, iterator.schema)
}
// Overloaded method for tests to access the iterator with metrics
private[sql] def fromIPCStreamWithIterator(input: Array[Byte], context: TaskContext):
- (InternalRowIteratorFromIPCStream, StructType) = {
- val iterator = new InternalRowIteratorFromIPCStream(input, context)
+ (InternalRowIteratorFromIPCStream, StructType) = {
+ val iterator = new InternalRowIteratorFromIPCStream(Iterator.single(input), context)
(iterator, iterator.schema)
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
index ccf6b63eb5de..f58a5b7ebd6a 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
@@ -1626,7 +1626,8 @@ class ArrowConvertersSuite extends SharedSparkSession {
writer.end()
}
- val (outputRowIter, outputSchema) = ArrowConverters.fromIPCStream(out.toByteArray, ctx)
+ val (outputRowIter, outputSchema) = ArrowConverters.
+ fromIPCStreamWithIterator(out.toByteArray, ctx)
assert(outputSchema == schema)
val res = outputRowIter.zipWithIndex.map { case (row, i) =>
assert(row.getInt(0) == i)
@@ -1663,7 +1664,8 @@ class ArrowConvertersSuite extends SharedSparkSession {
writer.end()
}
- val (outputRowIter, outputSchema) = ArrowConverters.fromIPCStream(out.toByteArray, ctx)
+ val (outputRowIter, outputSchema) = ArrowConverters
+ .fromIPCStreamWithIterator(out.toByteArray, ctx)
assert(outputSchema == schema)
val outputRows = outputRowIter.zipWithIndex.map { case (row, i) =>
assert(row.getInt(0) == i)
@@ -1760,7 +1762,7 @@ class ArrowConvertersSuite extends SharedSparkSession {
val invalidData = Array[Byte](1, 2, 3, 4, 5)
intercept[Exception] {
- ArrowConverters.fromIPCStream(invalidData, ctx)
+ ArrowConverters.fromIPCStreamWithIterator(invalidData, ctx)
}
}
@@ -1769,7 +1771,7 @@ class ArrowConvertersSuite extends SharedSparkSession {
val emptyData = Array.empty[Byte]
intercept[Exception] {
- ArrowConverters.fromIPCStream(emptyData, ctx)
+ ArrowConverters.fromIPCStreamWithIterator(emptyData, ctx)
}
}
@@ -1790,7 +1792,8 @@ class ArrowConvertersSuite extends SharedSparkSession {
// Test with null context - should still work but won't have cleanup registration
val proj = UnsafeProjection.create(schema)
- val (outputRowIter, outputSchema) = ArrowConverters.fromIPCStream(out.toByteArray, null)
+ val (outputRowIter, outputSchema) = ArrowConverters.
+ fromIPCStreamWithIterator(out.toByteArray, null)
assert(outputSchema == schema)
val outputRows = outputRowIter.map(proj(_).copy()).toList
assert(outputRows.length == inputRows.length)
@@ -1884,7 +1887,8 @@ class ArrowConvertersSuite extends SharedSparkSession {
writer.end()
}
- val (outputRowIter, outputSchema) = ArrowConverters.fromIPCStream(out.toByteArray, ctx)
+ val (outputRowIter, outputSchema) = ArrowConverters.
+ fromIPCStreamWithIterator(out.toByteArray, ctx)
val proj = UnsafeProjection.create(schema)
assert(outputSchema == schema)
val outputRows = outputRowIter.map(proj(_).copy()).toList
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]