This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 09a2cadc1fb4 [SPARK-54696][CONNECT] Clean-up Arrow Buffers - follow-up
09a2cadc1fb4 is described below
commit 09a2cadc1fb4c162565bb70610867d6f1aa10dee
Author: Herman van Hövell <[email protected]>
AuthorDate: Fri Dec 19 23:05:23 2025 +0800
[SPARK-54696][CONNECT] Clean-up Arrow Buffers - follow-up
### What changes were proposed in this pull request?
There were a couple of omissions in commit
https://github.com/apache/spark/commit/c36b7e58d0422a13228252657e4cff26a762a228
which this PR addresses. The following changes were made:
- Testing that Arrow buffers are actually cleaned up when IPC stream iterators are exhausted.
- Throw a proper error when there is a schema mismatch between different IPC streams (see the sketch after this list).
- Tidy up some duplicate code in the SparkConnectPlanner.
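For the schema-mismatch change, here is a minimal sketch (not part of this patch) of how `SparkException.internalError` behaves; the `getCondition` value and the message shape are standard Spark error-framework behavior rather than anything this diff adds:

```scala
import org.apache.spark.SparkException

// Illustrative only: internalError tags the exception with Spark's
// INTERNAL_ERROR error condition, which is why it replaces the bare
// IllegalStateException the reader used to throw on a schema mismatch.
val e = SparkException.internalError("IPC Streams have different schemas.")
assert(e.getCondition == "INTERNAL_ERROR")
assert(e.getMessage.contains("IPC Streams have different schemas."))
```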
### Why are the changes needed?
The previous PR was merged in a hurry. These things were missed.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
It adds tests for IPC stream iterators.
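The new assertions take the shape below (a sketch that mirrors the test diff further down; `out`, `proj`, and the surrounding setup come from the existing suite):

```scala
// Sketch of the new cleanup checks, mirroring the test diff below.
val (outputRowIter, outputSchema) =
  ArrowConverters.fromIPCStreamWithIterator(out.toByteArray, null)
assert(outputRowIter.peakMemoryAllocation == 0)  // nothing allocated yet
val outputRows = outputRowIter.map(proj(_).copy()).toList
assert(outputRowIter.peakMemoryAllocation > 0)   // buffers were allocated
assert(outputRowIter.allocatedMemory == 0)       // and freed once exhausted
```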
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #53480 from hvanhovell/SPARK-54696-follow-up-2.
Authored-by: Herman van Hövell <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../spark/sql/util/ConcatenatingArrowStreamReader.scala | 4 +++-
.../spark/sql/connect/planner/InvalidInputErrors.scala | 3 ---
.../apache/spark/sql/execution/arrow/ArrowConverters.scala | 4 +++-
.../spark/sql/execution/arrow/ArrowConvertersSuite.scala | 12 ++++++++++++
4 files changed, 18 insertions(+), 5 deletions(-)
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/util/ConcatenatingArrowStreamReader.scala b/sql/api/src/main/scala/org/apache/spark/sql/util/ConcatenatingArrowStreamReader.scala
index 5de53a568a7d..2e5706fe4dcc 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/util/ConcatenatingArrowStreamReader.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/util/ConcatenatingArrowStreamReader.scala
@@ -25,6 +25,8 @@ import org.apache.arrow.vector.ipc.{ArrowReader, ReadChannel}
import org.apache.arrow.vector.ipc.message.{ArrowDictionaryBatch, ArrowMessage, ArrowRecordBatch, MessageChannelReader, MessageResult, MessageSerializer}
import org.apache.arrow.vector.types.pojo.Schema
+import org.apache.spark.SparkException
+
/**
* An [[ArrowReader]] that concatenates multiple [[MessageIterator]]s into a single stream. Each
* iterator represents a single IPC stream. The concatenated streams all must have the same
@@ -62,7 +64,7 @@ private[sql] class ConcatenatingArrowStreamReader(
totalBytesRead += current.bytesRead
current = input.next()
if (current.schema != getVectorSchemaRoot.getSchema) {
- throw new IllegalStateException()
+ throw SparkException.internalError("IPC Streams have different schemas.")
}
}
if (current.hasNext) {
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/InvalidInputErrors.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/InvalidInputErrors.scala
index fcef696c88af..eb4df9673e59 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/InvalidInputErrors.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/InvalidInputErrors.scala
@@ -96,9 +96,6 @@ object InvalidInputErrors {
def chunkedCachedLocalRelationWithoutData(): InvalidPlanInput =
InvalidPlanInput("ChunkedCachedLocalRelation should contain data.")
- def chunkedCachedLocalRelationChunksWithDifferentSchema(): InvalidPlanInput =
- InvalidPlanInput("ChunkedCachedLocalRelation data chunks have different
schema.")
-
def schemaRequiredForLocalRelation(): InvalidPlanInput =
InvalidPlanInput("Schema for LocalRelation is required when the input data
is not provided.")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index 799996126e42..b1e8217ff257 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -313,7 +313,7 @@ private[sql] object ArrowConverters extends Logging {
new ConcatenatingArrowStreamReader(allocator, messages, destructive = true)
}
- val schema: StructType = try {
+ lazy val schema: StructType = try {
ArrowUtils.fromArrowSchema(reader.getVectorSchemaRoot.getSchema)
} catch {
case NonFatal(e) =>
@@ -337,6 +337,8 @@ private[sql] object ArrowConverters extends Logging {
// Public accessors for metrics
def batchesLoaded: Int = _batchesLoaded
def totalRowsProcessed: Long = _totalRowsProcessed
+ def allocatedMemory: Long = allocator.getAllocatedMemory
+ def peakMemoryAllocation: Long = allocator.getPeakMemoryAllocation
override def hasNext: Boolean = {
while (!rowIterator.hasNext) {
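The two new accessors delegate straight to Arrow's `BufferAllocator` accounting. A standalone sketch of that accounting (illustrative only, not code from this patch):

```scala
import org.apache.arrow.memory.RootAllocator

// Illustrative only: Arrow allocators track live and peak allocations, which
// is what allocatedMemory and peakMemoryAllocation expose to the tests.
val allocator = new RootAllocator(Long.MaxValue)
val buf = allocator.buffer(1024)                   // allocate one Arrow buffer
assert(allocator.getAllocatedMemory == 1024)       // live bytes
buf.close()                                        // release the buffer
assert(allocator.getAllocatedMemory == 0)          // everything returned
assert(allocator.getPeakMemoryAllocation >= 1024)  // high-water mark persists
allocator.close()
```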
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
index f58a5b7ebd6a..95cd97c2c742 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
@@ -1795,7 +1795,10 @@ class ArrowConvertersSuite extends SharedSparkSession {
val (outputRowIter, outputSchema) = ArrowConverters.
fromIPCStreamWithIterator(out.toByteArray, null)
assert(outputSchema == schema)
+ assert(outputRowIter.peakMemoryAllocation == 0)
val outputRows = outputRowIter.map(proj(_).copy()).toList
+ assert(outputRowIter.peakMemoryAllocation > 0)
+ assert(outputRowIter.allocatedMemory == 0)
assert(outputRows.length == inputRows.length)
outputRows.zipWithIndex.foreach { case (row, i) =>
assert(row.getInt(0) == i)
@@ -1820,6 +1823,7 @@ class ArrowConvertersSuite extends SharedSparkSession {
val (iterator, outputSchema) =
ArrowConverters.fromIPCStreamWithIterator(out.toByteArray, ctx)
assert(outputSchema == schema)
+ assert(iterator.peakMemoryAllocation == 0)
// Initially no batches loaded
assert(iterator.batchesLoaded == 0)
@@ -1837,6 +1841,8 @@ class ArrowConvertersSuite extends SharedSparkSession {
// Consume all rows
val proj = UnsafeProjection.create(schema)
val outputRows = iterator.map(proj(_).copy()).toList
+ assert(iterator.peakMemoryAllocation > 0)
+ assert(iterator.allocatedMemory == 0)
assert(outputRows.length == inputRows.length)
outputRows.zipWithIndex.foreach { case (row, i) =>
assert(row.getInt(0) == i)
@@ -1889,9 +1895,12 @@ class ArrowConvertersSuite extends SharedSparkSession {
val (outputRowIter, outputSchema) = ArrowConverters.
fromIPCStreamWithIterator(out.toByteArray, ctx)
+ assert(outputRowIter.peakMemoryAllocation == 0)
val proj = UnsafeProjection.create(schema)
assert(outputSchema == schema)
val outputRows = outputRowIter.map(proj(_).copy()).toList
+ assert(outputRowIter.peakMemoryAllocation > 0)
+ assert(outputRowIter.allocatedMemory == 0)
assert(outputRows.length == inputRows.length)
outputRows.zipWithIndex.foreach { case (row, i) =>
@@ -1926,6 +1935,7 @@ class ArrowConvertersSuite extends SharedSparkSession {
val (iterator, outputSchema) =
ArrowConverters.fromIPCStreamWithIterator(out.toByteArray, ctx)
assert(outputSchema == schema)
+ assert(iterator.peakMemoryAllocation == 0)
// Initially no batches loaded
assert(iterator.batchesLoaded == 0)
@@ -1945,6 +1955,8 @@ class ArrowConvertersSuite extends SharedSparkSession {
val remainingRows = iterator.toList
val totalConsumed = firstBatch.length + remainingRows.length
assert(totalConsumed == inputRows.length)
+ assert(iterator.peakMemoryAllocation > 0)
+ assert(iterator.allocatedMemory == 0)
// Final metrics should show all batches loaded
val expectedBatches = Math.ceil(inputRows.length.toDouble / batchSize).toInt