This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new a5e866f7ff57 [SPARK-54132][SQL][TESTS] Cover HashedRelation#close in HashedRelationSuite
a5e866f7ff57 is described below
commit a5e866f7ff57ef5f662d6b8882c680fcbcc5ac4a
Author: Hongze Zhang <[email protected]>
AuthorDate: Tue Nov 4 02:47:54 2025 +0800
[SPARK-54132][SQL][TESTS] Cover HashedRelation#close in HashedRelationSuite
### What changes were proposed in this pull request?
Add the following code to `HashedRelationSuite` so that the `HashedRelation#close` API is covered by the test suite. (To make the check possible, the per-test `TaskMemoryManager` instances are also replaced by the suite's shared `mm`, and the underlying `UnifiedMemoryManager` is exposed as `umm` so its `executionMemoryUsed` can be asserted after each test.)
```scala
protected override def afterEach(): Unit = {
  super.afterEach()
  assert(umm.executionMemoryUsed === 0)
}
```
### Why are the changes needed?
Doing this will:
1. Ensure `HashedRelation#close` is called in test code, which reduces the memory footprint and avoids memory leaks while the tests run (illustrated by the sketch below).
2. Ensure that implementations of `HashedRelation#close` free the allocated memory blocks correctly.

This is a standalone effort to improve test quality, and it is also a prerequisite for https://github.com/apache/spark/pull/52817.
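As a purely illustrative sketch (not part of this patch), a test written against the suite's existing fixtures (`mm`, `singleKey`, `contiguousRows`, `contiguousArray`) would now be expected to release its relation explicitly; omitting the `close()` call makes the new `afterEach` assertion fail:

```scala
// Hypothetical test, shown only to illustrate the pattern enforced by afterEach.
test("illustration: relation memory is released") {
  val hashedRelation = UnsafeHashedRelation(contiguousRows.iterator, singleKey, 1, mm)
  try {
    // Exercise the relation as usual.
    assert(hashedRelation.keys().map(_.getLong(0)).toArray === contiguousArray)
  } finally {
    // Frees the memory held by the relation; without this call
    // umm.executionMemoryUsed stays non-zero and afterEach fails.
    hashedRelation.close()
  }
}
```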
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
It's a test PR.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #52830 from zhztheplayer/wip-54132.
Authored-by: Hongze Zhang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../sql/execution/joins/HashedRelationSuite.scala | 90 ++++++++++------------
1 file changed, 39 insertions(+), 51 deletions(-)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index 6da5e0b1a123..b88a76bbfb57 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -40,14 +40,13 @@ import org.apache.spark.util.ArrayImplicits._
import org.apache.spark.util.collection.CompactBuffer
class HashedRelationSuite extends SharedSparkSession {
+ val umm = new UnifiedMemoryManager(
+ new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
+ Long.MaxValue,
+ Long.MaxValue / 2,
+ 1)
- val mm = new TaskMemoryManager(
- new UnifiedMemoryManager(
- new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
- Long.MaxValue,
- Long.MaxValue / 2,
- 1),
- 0)
+ val mm = new TaskMemoryManager(umm, 0)
val rand = new Random(100)
@@ -64,6 +63,11 @@ class HashedRelationSuite extends SharedSparkSession {
val sparseRows = sparseArray.map(i => projection(InternalRow(i.toLong)).copy())
val randomRows = randomArray.map(i => projection(InternalRow(i.toLong)).copy())
+ protected override def afterEach(): Unit = {
+ super.afterEach()
+ assert(umm.executionMemoryUsed === 0)
+ }
+
test("UnsafeHashedRelation") {
val schema = StructType(StructField("a", IntegerType, true) :: Nil)
val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
@@ -87,6 +91,7 @@ class HashedRelationSuite extends SharedSparkSession {
val out = new ObjectOutputStream(os)
hashed.asInstanceOf[UnsafeHashedRelation].writeExternal(out)
out.flush()
+ hashed.close()
val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
val hashed2 = new UnsafeHashedRelation()
hashed2.readExternal(in)
@@ -108,19 +113,13 @@ class HashedRelationSuite extends SharedSparkSession {
}
test("test serialization empty hash map") {
- val taskMemoryManager = new TaskMemoryManager(
- new UnifiedMemoryManager(
- new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
- Long.MaxValue,
- Long.MaxValue / 2,
- 1),
- 0)
- val binaryMap = new BytesToBytesMap(taskMemoryManager, 1, 1)
+ val binaryMap = new BytesToBytesMap(mm, 1, 1)
val os = new ByteArrayOutputStream()
val out = new ObjectOutputStream(os)
val hashed = new UnsafeHashedRelation(1, 1, binaryMap)
hashed.writeExternal(out)
out.flush()
+ hashed.close()
val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
val hashed2 = new UnsafeHashedRelation()
hashed2.readExternal(in)
@@ -149,9 +148,10 @@ class HashedRelationSuite extends SharedSparkSession {
assert(row.getLong(0) === i)
assert(row.getInt(1) === i + 1)
}
+ longRelation.close()
val longRelation2 = LongHashedRelation(rows.iterator ++ rows.iterator, key, 100, mm)
- .asInstanceOf[LongHashedRelation]
+ .asInstanceOf[LongHashedRelation]
assert(!longRelation2.keyIsUnique)
(0 until 100).foreach { i =>
val rows = longRelation2.get(i).toArray
@@ -166,6 +166,7 @@ class HashedRelationSuite extends SharedSparkSession {
val out = new ObjectOutputStream(os)
longRelation2.writeExternal(out)
out.flush()
+ longRelation2.close()
val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
val relation = new LongHashedRelation()
relation.readExternal(in)
@@ -181,19 +182,12 @@ class HashedRelationSuite extends SharedSparkSession {
}
test("LongToUnsafeRowMap with very wide range") {
- val taskMemoryManager = new TaskMemoryManager(
- new UnifiedMemoryManager(
- new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
- Long.MaxValue,
- Long.MaxValue / 2,
- 1),
- 0)
val unsafeProj = UnsafeProjection.create(Seq(BoundReference(0, LongType, false)))
{
// SPARK-16740
val keys = Seq(0L, Long.MaxValue, Long.MaxValue)
- val map = new LongToUnsafeRowMap(taskMemoryManager, 1)
+ val map = new LongToUnsafeRowMap(mm, 1)
keys.foreach { k =>
map.append(k, unsafeProj(InternalRow(k)))
}
@@ -210,7 +204,7 @@ class HashedRelationSuite extends SharedSparkSession {
{
// SPARK-16802
val keys = Seq(Long.MaxValue, Long.MaxValue - 10)
- val map = new LongToUnsafeRowMap(taskMemoryManager, 1)
+ val map = new LongToUnsafeRowMap(mm, 1)
keys.foreach { k =>
map.append(k, unsafeProj(InternalRow(k)))
}
@@ -226,20 +220,13 @@ class HashedRelationSuite extends SharedSparkSession {
}
test("LongToUnsafeRowMap with random keys") {
- val taskMemoryManager = new TaskMemoryManager(
- new UnifiedMemoryManager(
- new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
- Long.MaxValue,
- Long.MaxValue / 2,
- 1),
- 0)
val unsafeProj = UnsafeProjection.create(Seq(BoundReference(0, LongType, false)))
val N = 1000000
val rand = new Random
val keys = (0 to N).map(x => rand.nextLong()).toArray
- val map = new LongToUnsafeRowMap(taskMemoryManager, 10)
+ val map = new LongToUnsafeRowMap(mm, 10)
keys.foreach { k =>
map.append(k, unsafeProj(InternalRow(k)))
}
@@ -249,8 +236,9 @@ class HashedRelationSuite extends SharedSparkSession {
val out = new ObjectOutputStream(os)
map.writeExternal(out)
out.flush()
+ map.free()
val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
- val map2 = new LongToUnsafeRowMap(taskMemoryManager, 1)
+ val map2 = new LongToUnsafeRowMap(mm, 1)
map2.readExternal(in)
val row = unsafeProj(InternalRow(0L)).copy()
@@ -276,19 +264,12 @@ class HashedRelationSuite extends SharedSparkSession {
}
i += 1
}
- map.free()
+ map2.free()
}
test("SPARK-24257: insert big values into LongToUnsafeRowMap") {
- val taskMemoryManager = new TaskMemoryManager(
- new UnifiedMemoryManager(
- new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
- Long.MaxValue,
- Long.MaxValue / 2,
- 1),
- 0)
val unsafeProj = UnsafeProjection.create(Array[DataType](StringType))
- val map = new LongToUnsafeRowMap(taskMemoryManager, 1)
+ val map = new LongToUnsafeRowMap(mm, 1)
val key = 0L
// the page array is initialized with length 1 << 17 (1M bytes),
@@ -343,6 +324,7 @@ class HashedRelationSuite extends SharedSparkSession {
val rows = (0 until 100).map(i => unsafeProj(InternalRow(Int.int2long(i), i + 1)).copy())
val longRelation = LongHashedRelation(rows.iterator ++ rows.iterator, key, 100, mm)
val longRelation2 = ser.deserialize[LongHashedRelation](ser.serialize(longRelation))
+ longRelation.close()
(0 until 100).foreach { i =>
val rows = longRelation2.get(i).toArray
assert(rows.length === 2)
@@ -359,6 +341,7 @@ class HashedRelationSuite extends SharedSparkSession {
unsafeHashed.asInstanceOf[UnsafeHashedRelation].writeExternal(out)
out.flush()
val unsafeHashed2 = ser.deserialize[UnsafeHashedRelation](ser.serialize(unsafeHashed))
+ unsafeHashed.close()
val os2 = new ByteArrayOutputStream()
val out2 = new ObjectOutputStream(os2)
unsafeHashed2.writeExternal(out2)
@@ -398,6 +381,7 @@ class HashedRelationSuite extends SharedSparkSession {
thread2.join()
val unsafeHashed2 = ser.deserialize[UnsafeHashedRelation](ser.serialize(unsafeHashed))
+ unsafeHashed.close()
val os2 = new ByteArrayOutputStream()
val out2 = new ObjectOutputStream(os2)
unsafeHashed2.writeExternal(out2)
@@ -452,18 +436,21 @@ class HashedRelationSuite extends SharedSparkSession {
val hashedRelation = UnsafeHashedRelation(contiguousRows.iterator, singleKey, 1, mm)
val keyIterator = hashedRelation.keys()
assert(keyIterator.map(key => key.getLong(0)).toArray === contiguousArray)
+ hashedRelation.close()
}
test("UnsafeHashedRelation: key set iterator on a sparse array of keys") {
val hashedRelation = UnsafeHashedRelation(sparseRows.iterator, singleKey, 1, mm)
val keyIterator = hashedRelation.keys()
assert(keyIterator.map(key => key.getLong(0)).toArray === sparseArray)
+ hashedRelation.close()
}
test("LongHashedRelation: key set iterator on a contiguous array of keys") {
val longRelation = LongHashedRelation(contiguousRows.iterator, singleKey, 1, mm)
val keyIterator = longRelation.keys()
assert(keyIterator.map(key => key.getLong(0)).toArray === contiguousArray)
+ longRelation.close()
}
test("LongToUnsafeRowMap: key set iterator on a contiguous array of keys") {
@@ -478,6 +465,7 @@ class HashedRelationSuite extends SharedSparkSession {
rowMap.optimize()
keyIterator = rowMap.keys()
assert(keyIterator.map(key => key.getLong(0)).toArray === contiguousArray)
+ rowMap.free()
}
test("LongToUnsafeRowMap: key set iterator on a sparse array with
equidistant keys") {
@@ -490,6 +478,7 @@ class HashedRelationSuite extends SharedSparkSession {
rowMap.optimize()
keyIterator = rowMap.keys()
assert(keyIterator.map(_.getLong(0)).toArray === sparseArray)
+ rowMap.free()
}
test("LongToUnsafeRowMap: key set iterator on an array with a single key") {
@@ -530,6 +519,7 @@ class HashedRelationSuite extends SharedSparkSession {
buffer.append(keyIterator.next().getLong(0))
}
assert(buffer === randomArray)
+ rowMap.free()
}
test("LongToUnsafeRowMap: no explicit hasNext calls on the key iterator") {
@@ -560,6 +550,7 @@ class HashedRelationSuite extends SharedSparkSession {
buffer.append(keyIterator.next().getLong(0))
}
assert(buffer === randomArray)
+ rowMap.free()
}
test("LongToUnsafeRowMap: call hasNext at the end of the iterator") {
@@ -577,6 +568,7 @@ class HashedRelationSuite extends SharedSparkSession {
assert(keyIterator.map(key => key.getLong(0)).toArray === sparseArray)
assert(keyIterator.hasNext == false)
assert(keyIterator.hasNext == false)
+ rowMap.free()
}
test("LongToUnsafeRowMap: random sequence of hasNext and next() calls on the
key iterator") {
@@ -607,6 +599,7 @@ class HashedRelationSuite extends SharedSparkSession {
}
}
assert(buffer === randomArray)
+ rowMap.free()
}
test("HashJoin: packing and unpacking with the same key type in a LongType")
{
@@ -661,6 +654,7 @@ class HashedRelationSuite extends SharedSparkSession {
assert(hashed.keys().isEmpty)
assert(hashed.keyIsUnique)
assert(hashed.estimatedSize == 0)
+ hashed.close()
}
test("SPARK-32399: test methods related to key index") {
@@ -739,20 +733,14 @@ class HashedRelationSuite extends SharedSparkSession {
val actualValues = row.map(_._2.getInt(1))
assert(actualValues === expectedValues)
}
+ unsafeRelation.close()
}
test("LongToUnsafeRowMap support ignoresDuplicatedKey") {
- val taskMemoryManager = new TaskMemoryManager(
- new UnifiedMemoryManager(
- new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
- Long.MaxValue,
- Long.MaxValue / 2,
- 1),
- 0)
val unsafeProj = UnsafeProjection.create(Seq(BoundReference(0, LongType, false)))
val keys = Seq(1L, 1L, 1L)
Seq(true, false).foreach { ignoresDuplicatedKey =>
- val map = new LongToUnsafeRowMap(taskMemoryManager, 1, ignoresDuplicatedKey)
+ val map = new LongToUnsafeRowMap(mm, 1, ignoresDuplicatedKey)
keys.foreach { k =>
map.append(k, unsafeProj(InternalRow(k)))
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]