This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a5e866f7ff57 [SPARK-54132][SQL][TESTS] Cover HashedRelation#close in HashedRelationSuite
a5e866f7ff57 is described below

commit a5e866f7ff57ef5f662d6b8882c680fcbcc5ac4a
Author: Hongze Zhang <[email protected]>
AuthorDate: Tue Nov 4 02:47:54 2025 +0800

    [SPARK-54132][SQL][TESTS] Cover HashedRelation#close in HashedRelationSuite
    
    ### What changes were proposed in this pull request?
    
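    Add the following code to `HashedRelationSuite` to cover the
    `HashedRelation#close` API in the test suite. All tests now share a single
    `UnifiedMemoryManager` (`umm`) and its `TaskMemoryManager` (`mm`). Each test
    calls `close()` (or `free()` for `LongToUnsafeRowMap`) on the relations and
    maps it creates. The hook below then asserts that no execution memory is
    still allocated after each test: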
    
    ```scala
      protected override def afterEach(): Unit = {
        super.afterEach()
        assert(umm.executionMemoryUsed === 0)
      }
    ```
    
    ### Why are the changes needed?
    
    Doing this will:
    
    1. Ensure `HashedRelation#close` is called in test code, lowering the memory
       footprint and avoiding memory leaks while the tests run.
    2. Ensure implementations of `HashedRelation#close` correctly free the
       memory blocks they allocated.
    
    This is a standalone improvement to test quality, and also a prerequisite
    for https://github.com/apache/spark/pull/52817.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    This is a test-only change; the updated `HashedRelationSuite` itself
    verifies it.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #52830 from zhztheplayer/wip-54132.
    
    Authored-by: Hongze Zhang <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../sql/execution/joins/HashedRelationSuite.scala  | 90 ++++++++++------------
 1 file changed, 39 insertions(+), 51 deletions(-)

diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index 6da5e0b1a123..b88a76bbfb57 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -40,14 +40,13 @@ import org.apache.spark.util.ArrayImplicits._
 import org.apache.spark.util.collection.CompactBuffer
 
 class HashedRelationSuite extends SharedSparkSession {
+  val umm = new UnifiedMemoryManager(
+    new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
+    Long.MaxValue,
+    Long.MaxValue / 2,
+    1)
 
-  val mm = new TaskMemoryManager(
-    new UnifiedMemoryManager(
-      new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
-      Long.MaxValue,
-      Long.MaxValue / 2,
-      1),
-    0)
+  val mm = new TaskMemoryManager(umm, 0)
 
   val rand = new Random(100)
 
@@ -64,6 +63,11 @@ class HashedRelationSuite extends SharedSparkSession {
   val sparseRows = sparseArray.map(i => 
projection(InternalRow(i.toLong)).copy())
   val randomRows = randomArray.map(i => 
projection(InternalRow(i.toLong)).copy())
 
+  protected override def afterEach(): Unit = {
+    super.afterEach()
+    assert(umm.executionMemoryUsed === 0)
+  }
+
   test("UnsafeHashedRelation") {
     val schema = StructType(StructField("a", IntegerType, true) :: Nil)
     val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), 
InternalRow(2))
@@ -87,6 +91,7 @@ class HashedRelationSuite extends SharedSparkSession {
     val out = new ObjectOutputStream(os)
     hashed.asInstanceOf[UnsafeHashedRelation].writeExternal(out)
     out.flush()
+    hashed.close()
     val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
     val hashed2 = new UnsafeHashedRelation()
     hashed2.readExternal(in)
@@ -108,19 +113,13 @@ class HashedRelationSuite extends SharedSparkSession {
   }
 
   test("test serialization empty hash map") {
-    val taskMemoryManager = new TaskMemoryManager(
-      new UnifiedMemoryManager(
-        new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
-        Long.MaxValue,
-        Long.MaxValue / 2,
-        1),
-      0)
-    val binaryMap = new BytesToBytesMap(taskMemoryManager, 1, 1)
+    val binaryMap = new BytesToBytesMap(mm, 1, 1)
     val os = new ByteArrayOutputStream()
     val out = new ObjectOutputStream(os)
     val hashed = new UnsafeHashedRelation(1, 1, binaryMap)
     hashed.writeExternal(out)
     out.flush()
+    hashed.close()
     val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
     val hashed2 = new UnsafeHashedRelation()
     hashed2.readExternal(in)
@@ -149,9 +148,10 @@ class HashedRelationSuite extends SharedSparkSession {
       assert(row.getLong(0) === i)
       assert(row.getInt(1) === i + 1)
     }
+    longRelation.close()
 
     val longRelation2 = LongHashedRelation(rows.iterator ++ rows.iterator, 
key, 100, mm)
-        .asInstanceOf[LongHashedRelation]
+      .asInstanceOf[LongHashedRelation]
     assert(!longRelation2.keyIsUnique)
     (0 until 100).foreach { i =>
       val rows = longRelation2.get(i).toArray
@@ -166,6 +166,7 @@ class HashedRelationSuite extends SharedSparkSession {
     val out = new ObjectOutputStream(os)
     longRelation2.writeExternal(out)
     out.flush()
+    longRelation2.close()
     val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
     val relation = new LongHashedRelation()
     relation.readExternal(in)
@@ -181,19 +182,12 @@ class HashedRelationSuite extends SharedSparkSession {
   }
 
   test("LongToUnsafeRowMap with very wide range") {
-    val taskMemoryManager = new TaskMemoryManager(
-      new UnifiedMemoryManager(
-        new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
-        Long.MaxValue,
-        Long.MaxValue / 2,
-        1),
-      0)
     val unsafeProj = UnsafeProjection.create(Seq(BoundReference(0, LongType, 
false)))
 
     {
       // SPARK-16740
       val keys = Seq(0L, Long.MaxValue, Long.MaxValue)
-      val map = new LongToUnsafeRowMap(taskMemoryManager, 1)
+      val map = new LongToUnsafeRowMap(mm, 1)
       keys.foreach { k =>
         map.append(k, unsafeProj(InternalRow(k)))
       }
@@ -210,7 +204,7 @@ class HashedRelationSuite extends SharedSparkSession {
     {
       // SPARK-16802
       val keys = Seq(Long.MaxValue, Long.MaxValue - 10)
-      val map = new LongToUnsafeRowMap(taskMemoryManager, 1)
+      val map = new LongToUnsafeRowMap(mm, 1)
       keys.foreach { k =>
         map.append(k, unsafeProj(InternalRow(k)))
       }
@@ -226,20 +220,13 @@ class HashedRelationSuite extends SharedSparkSession {
   }
 
   test("LongToUnsafeRowMap with random keys") {
-    val taskMemoryManager = new TaskMemoryManager(
-      new UnifiedMemoryManager(
-        new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
-        Long.MaxValue,
-        Long.MaxValue / 2,
-        1),
-      0)
     val unsafeProj = UnsafeProjection.create(Seq(BoundReference(0, LongType, 
false)))
 
     val N = 1000000
     val rand = new Random
     val keys = (0 to N).map(x => rand.nextLong()).toArray
 
-    val map = new LongToUnsafeRowMap(taskMemoryManager, 10)
+    val map = new LongToUnsafeRowMap(mm, 10)
     keys.foreach { k =>
       map.append(k, unsafeProj(InternalRow(k)))
     }
@@ -249,8 +236,9 @@ class HashedRelationSuite extends SharedSparkSession {
     val out = new ObjectOutputStream(os)
     map.writeExternal(out)
     out.flush()
+    map.free()
     val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
-    val map2 = new LongToUnsafeRowMap(taskMemoryManager, 1)
+    val map2 = new LongToUnsafeRowMap(mm, 1)
     map2.readExternal(in)
 
     val row = unsafeProj(InternalRow(0L)).copy()
@@ -276,19 +264,12 @@ class HashedRelationSuite extends SharedSparkSession {
       }
       i += 1
     }
-    map.free()
+    map2.free()
   }
 
   test("SPARK-24257: insert big values into LongToUnsafeRowMap") {
-    val taskMemoryManager = new TaskMemoryManager(
-      new UnifiedMemoryManager(
-        new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
-        Long.MaxValue,
-        Long.MaxValue / 2,
-        1),
-      0)
     val unsafeProj = UnsafeProjection.create(Array[DataType](StringType))
-    val map = new LongToUnsafeRowMap(taskMemoryManager, 1)
+    val map = new LongToUnsafeRowMap(mm, 1)
 
     val key = 0L
     // the page array is initialized with length 1 << 17 (1M bytes),
@@ -343,6 +324,7 @@ class HashedRelationSuite extends SharedSparkSession {
     val rows = (0 until 100).map(i => unsafeProj(InternalRow(Int.int2long(i), 
i + 1)).copy())
     val longRelation = LongHashedRelation(rows.iterator ++ rows.iterator, key, 
100, mm)
     val longRelation2 = 
ser.deserialize[LongHashedRelation](ser.serialize(longRelation))
+    longRelation.close()
     (0 until 100).foreach { i =>
       val rows = longRelation2.get(i).toArray
       assert(rows.length === 2)
@@ -359,6 +341,7 @@ class HashedRelationSuite extends SharedSparkSession {
     unsafeHashed.asInstanceOf[UnsafeHashedRelation].writeExternal(out)
     out.flush()
     val unsafeHashed2 = 
ser.deserialize[UnsafeHashedRelation](ser.serialize(unsafeHashed))
+    unsafeHashed.close()
     val os2 = new ByteArrayOutputStream()
     val out2 = new ObjectOutputStream(os2)
     unsafeHashed2.writeExternal(out2)
@@ -398,6 +381,7 @@ class HashedRelationSuite extends SharedSparkSession {
     thread2.join()
 
     val unsafeHashed2 = 
ser.deserialize[UnsafeHashedRelation](ser.serialize(unsafeHashed))
+    unsafeHashed.close()
     val os2 = new ByteArrayOutputStream()
     val out2 = new ObjectOutputStream(os2)
     unsafeHashed2.writeExternal(out2)
@@ -452,18 +436,21 @@ class HashedRelationSuite extends SharedSparkSession {
     val hashedRelation = UnsafeHashedRelation(contiguousRows.iterator, 
singleKey, 1, mm)
     val keyIterator = hashedRelation.keys()
     assert(keyIterator.map(key => key.getLong(0)).toArray === contiguousArray)
+    hashedRelation.close()
   }
 
   test("UnsafeHashedRelation: key set iterator on a sparse array of keys") {
     val hashedRelation = UnsafeHashedRelation(sparseRows.iterator, singleKey, 
1, mm)
     val keyIterator = hashedRelation.keys()
     assert(keyIterator.map(key => key.getLong(0)).toArray === sparseArray)
+    hashedRelation.close()
   }
 
   test("LongHashedRelation: key set iterator on a contiguous array of keys") {
     val longRelation = LongHashedRelation(contiguousRows.iterator, singleKey, 
1, mm)
     val keyIterator = longRelation.keys()
     assert(keyIterator.map(key => key.getLong(0)).toArray === contiguousArray)
+    longRelation.close()
   }
 
   test("LongToUnsafeRowMap: key set iterator on a contiguous array of keys") {
@@ -478,6 +465,7 @@ class HashedRelationSuite extends SharedSparkSession {
     rowMap.optimize()
     keyIterator = rowMap.keys()
     assert(keyIterator.map(key => key.getLong(0)).toArray === contiguousArray)
+    rowMap.free()
   }
 
   test("LongToUnsafeRowMap: key set iterator on a sparse array with 
equidistant keys") {
@@ -490,6 +478,7 @@ class HashedRelationSuite extends SharedSparkSession {
     rowMap.optimize()
     keyIterator = rowMap.keys()
     assert(keyIterator.map(_.getLong(0)).toArray === sparseArray)
+    rowMap.free()
   }
 
   test("LongToUnsafeRowMap: key set iterator on an array with a single key") {
@@ -530,6 +519,7 @@ class HashedRelationSuite extends SharedSparkSession {
       buffer.append(keyIterator.next().getLong(0))
     }
     assert(buffer === randomArray)
+    rowMap.free()
   }
 
   test("LongToUnsafeRowMap: no explicit hasNext calls on the key iterator") {
@@ -560,6 +550,7 @@ class HashedRelationSuite extends SharedSparkSession {
       buffer.append(keyIterator.next().getLong(0))
     }
     assert(buffer === randomArray)
+    rowMap.free()
   }
 
   test("LongToUnsafeRowMap: call hasNext at the end of the iterator") {
@@ -577,6 +568,7 @@ class HashedRelationSuite extends SharedSparkSession {
     assert(keyIterator.map(key => key.getLong(0)).toArray === sparseArray)
     assert(keyIterator.hasNext == false)
     assert(keyIterator.hasNext == false)
+    rowMap.free()
   }
 
   test("LongToUnsafeRowMap: random sequence of hasNext and next() calls on the 
key iterator") {
@@ -607,6 +599,7 @@ class HashedRelationSuite extends SharedSparkSession {
       }
     }
     assert(buffer === randomArray)
+    rowMap.free()
   }
 
   test("HashJoin: packing and unpacking with the same key type in a LongType") 
{
@@ -661,6 +654,7 @@ class HashedRelationSuite extends SharedSparkSession {
     assert(hashed.keys().isEmpty)
     assert(hashed.keyIsUnique)
     assert(hashed.estimatedSize == 0)
+    hashed.close()
   }
 
   test("SPARK-32399: test methods related to key index") {
@@ -739,20 +733,14 @@ class HashedRelationSuite extends SharedSparkSession {
         val actualValues = row.map(_._2.getInt(1))
         assert(actualValues === expectedValues)
     }
+    unsafeRelation.close()
   }
 
   test("LongToUnsafeRowMap support ignoresDuplicatedKey") {
-    val taskMemoryManager = new TaskMemoryManager(
-      new UnifiedMemoryManager(
-        new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
-        Long.MaxValue,
-        Long.MaxValue / 2,
-        1),
-      0)
     val unsafeProj = UnsafeProjection.create(Seq(BoundReference(0, LongType, 
false)))
     val keys = Seq(1L, 1L, 1L)
     Seq(true, false).foreach { ignoresDuplicatedKey =>
-      val map = new LongToUnsafeRowMap(taskMemoryManager, 1, 
ignoresDuplicatedKey)
+      val map = new LongToUnsafeRowMap(mm, 1, ignoresDuplicatedKey)
       keys.foreach { k =>
         map.append(k, unsafeProj(InternalRow(k)))
       }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to