Repository: spark
Updated Branches:
  refs/heads/master 66924ffa6 -> 2e981b7bf


[SPARK-9531] [SQL] UnsafeFixedWidthAggregationMap.destructAndCreateExternalSorter

This pull request adds a destructAndCreateExternalSorter method to UnsafeFixedWidthAggregationMap. The new method does the following:

1. Creates a new external sorter, UnsafeKVExternalSorter.
2. Inserts all the data into an in-memory sorter and sorts it.
3. Spills the sorted in-memory data to disk.

This method can be used to fall back to sort-based aggregation under memory pressure.

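A minimal sketch of those three steps follows. It assumes a plain `HashMap<String, Long>` as the aggregation map and a text file as the spill format. `AggregationFallbackSketch`, `destructAndSpill` and `readSpill` are hypothetical names and are not part of this commit. The BytesToBytesMap Javadoc in the diff below says the real map and sorter share a record format, so records can be handed to the sorter in place. This sketch copies them instead, to keep it simple.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/** Toy model of "destruct the map, sort in memory, spill to disk" (hypothetical names). */
final class AggregationFallbackSketch {

  /**
   * Moves every (key, value) pair out of the aggregation map, sorts the pairs by key,
   * writes the sorted run to a temporary spill file, and clears the map.
   */
  static Path destructAndSpill(Map<String, Long> aggregationMap) {
    // Steps 1-2: copy all pairs into an in-memory buffer and sort them by key.
    List<Map.Entry<String, Long>> inMemory = new ArrayList<>(aggregationMap.entrySet());
    inMemory.sort(Map.Entry.comparingByKey());
    List<String> lines = new ArrayList<>();
    for (Map.Entry<String, Long> e : inMemory) {
      lines.add(e.getKey() + "\t" + e.getValue());
    }
    try {
      // Step 3: spill the sorted run to disk.
      Path spillFile = Files.createTempFile("agg-spill", ".txt");
      Files.write(spillFile, lines, StandardCharsets.UTF_8);
      // "Destruct": the hash map is no longer needed once its contents are on disk.
      aggregationMap.clear();
      return spillFile;
    } catch (IOException e) {
      throw new UncheckedIOException(e);
    }
  }

  /** Reads a spilled run back as "key<TAB>value" lines in ascending key order. */
  static List<String> readSpill(Path spillFile) {
    try {
      return Files.readAllLines(spillFile, StandardCharsets.UTF_8);
    } catch (IOException e) {
      throw new UncheckedIOException(e);
    }
  }
}
```

After the call the map is empty, and the only copy of the data is the sorted run on disk. This mirrors the memory-pressure motivation above: the hash map's memory can be reclaimed before aggregation continues in sort-based mode.
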
The pull request also includes accounting fixes from JoshRosen.

TODOs (that can be done in follow-up PRs):
- [x] Address Josh's feedback from #7849
- [x] More documentation and test cases
- [x] Make sure we are doing memory accounting correctly, with test cases (e.g. did we release the memory in BytesToBytesMap twice?)
- [ ] Look harder at possible memory leaks and exception handling
- [ ] Randomized tester for the KV sorter as well as the aggregation map

Author: Reynold Xin <[email protected]>
Author: Josh Rosen <[email protected]>

Closes #7860 from rxin/kvsorter and squashes the following commits:

986a58c [Reynold Xin] Bug fix.
599317c [Reynold Xin] Style fix and slightly more compact code.
fe7bd4e [Reynold Xin] Bug fixes.
fd71bef [Reynold Xin] Merge remote-tracking branch 'josh/large-records-in-sql-sorter' into kvsorter-with-josh-fix
3efae38 [Reynold Xin] More fixes and documentation.
45f1b09 [Josh Rosen] Ensure that spill files are cleaned up
f6a9bd3 [Reynold Xin] Josh feedback.
9be8139 [Reynold Xin] Remove testSpillFrequency.
7cbe759 [Reynold Xin] [SPARK-9531][SQL] UnsafeFixedWidthAggregationMap.destructAndCreateExternalSorter.
ae4a8af [Josh Rosen] Detect leaked unsafe memory in UnsafeExternalSorterSuite.
52f9b06 [Josh Rosen] Detect ShuffleMemoryManager leaks in UnsafeExternalSorter.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2e981b7b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2e981b7b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2e981b7b

Branch: refs/heads/master
Commit: 2e981b7bfa9dec93fdcf25f3e7220cd6aaba744f
Parents: 66924ff
Author: Reynold Xin <[email protected]>
Authored: Sun Aug 2 12:32:14 2015 -0700
Committer: Josh Rosen <[email protected]>
Committed: Sun Aug 2 12:32:14 2015 -0700

----------------------------------------------------------------------
 .../spark/unsafe/map/BytesToBytesMap.java       |  32 ++-
 .../unsafe/sort/UnsafeExternalSorter.java       | 197 ++++++++++++----
 .../unsafe/sort/UnsafeInMemorySorter.java       |   4 +
 .../unsafe/sort/UnsafeSorterSpillReader.java    |   3 +
 .../unsafe/sort/UnsafeSorterSpillWriter.java    |   4 +
 .../map/AbstractBytesToBytesMapSuite.java       |   7 +-
 .../unsafe/sort/UnsafeExternalSorterSuite.java  |  65 +++--
 .../sql/catalyst/expressions/UnsafeRow.java     |   3 +-
 .../sql/execution/UnsafeExternalRowSorter.java  |   9 +-
 sql/core/pom.xml                                |   5 +
 .../UnsafeFixedWidthAggregationMap.java         | 103 +-------
 .../sql/execution/UnsafeKVExternalSorter.java   | 236 +++++++++++++++++++
 .../spark/sql/execution/SortPrefixUtils.scala   |  33 ++-
 .../org/apache/spark/sql/execution/sort.scala   |   4 +-
 .../execution/TestShuffleMemoryManager.scala    |  51 ++++
 .../UnsafeFixedWidthAggregationMapSuite.scala   | 124 +++++++---
 .../execution/UnsafeKVExternalSorterSuite.scala | 158 +++++++++++++
 17 files changed, 823 insertions(+), 215 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2e981b7b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index cf222b7..01a6608 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -39,14 +39,22 @@ import org.apache.spark.unsafe.memory.*;
 
 /**
  * An append-only hash map where keys and values are contiguous regions of bytes.
- * <p>
+ *
  * This is backed by a power-of-2-sized hash table, using quadratic probing with triangular numbers,
  * which is guaranteed to exhaust the space.
- * <p>
+ *
  * The map can support up to 2^29 keys. If the key cardinality is higher than this, you should
  * probably be using sorting instead of hashing for better cache locality.
- * <p>
- * This class is not thread safe.
+ *
+ * The key and values under the hood are stored together, in the following format:
+ *   Bytes 0 to 4: len(k) (key length in bytes) + len(v) (value length in bytes) + 4
+ *   Bytes 4 to 8: len(k)
+ *   Bytes 8 to 8 + len(k): key data
+ *   Bytes 8 + len(k) to 8 + len(k) + len(v): value data
+ *
+ * This means that the first four bytes store the entire record (key + value) length. This format
+ * is consistent with {@link org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter},
+ * so we can pass records from this map directly into the sorter to sort records in place.
  */
 public final class BytesToBytesMap {
 
@@ -253,7 +261,7 @@ public final class BytesToBytesMap {
         totalLength = PlatformDependent.UNSAFE.getInt(pageBaseObject, offsetInPage);
       }
       loc.with(currentPage, offsetInPage);
-      offsetInPage += 8 + totalLength;
+      offsetInPage += 4 + totalLength;
       currentRecordNumber++;
       return loc;
     }
@@ -366,7 +374,7 @@ public final class BytesToBytesMap {
       position += 4;
       keyLength = PlatformDependent.UNSAFE.getInt(page, position);
       position += 4;
-      valueLength = totalLength - keyLength;
+      valueLength = totalLength - keyLength - 4;
 
       keyMemoryLocation.setObjAndOffset(page, position);
 
@@ -565,7 +573,7 @@ public final class BytesToBytesMap {
       insertCursor += valueLengthBytes; // word used to store the value size
 
       PlatformDependent.UNSAFE.putInt(dataPageBaseObject, recordOffset,
-        keyLengthBytes + valueLengthBytes);
+        keyLengthBytes + valueLengthBytes + 4);
       PlatformDependent.UNSAFE.putInt(dataPageBaseObject, keyLengthOffset, keyLengthBytes);
       // Copy the key
       PlatformDependent.copyMemory(
@@ -620,7 +628,7 @@ public final class BytesToBytesMap {
    * Free all allocated memory associated with this map, including the storage for keys and values
    * as well as the hash map array itself.
    *
-   * This method is idempotent.
+   * This method is idempotent and can be called multiple times.
    */
   public void free() {
     longArray = null;
@@ -639,6 +647,14 @@ public final class BytesToBytesMap {
     return taskMemoryManager;
   }
 
+  public ShuffleMemoryManager getShuffleMemoryManager() {
+    return shuffleMemoryManager;
+  }
+
+  public long getPageSizeBytes() {
+    return pageSizeBytes;
+  }
+
   /** Returns the total amount of memory, in bytes, consumed by this map's managed structures. */
   public long getTotalMemoryConsumption() {
     long totalDataPagesSize = 0L;

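The sketch below makes the new BytesToBytesMap record layout concrete. It encodes and decodes records with a heap `java.nio.ByteBuffer` instead of Spark's memory pages; `KVRecordLayout` and its methods are hypothetical. It also shows why the three arithmetic changes above move together. Once the stored length includes the 4-byte key-length word, the iterator must advance by `4 + totalLength`, and the value length becomes `totalLength - keyLength - 4`.

```java
import java.nio.ByteBuffer;

/** Hypothetical encoder/decoder for the [totalLen][keyLen][key][value] record layout. */
final class KVRecordLayout {

  /** Appends one record at the buffer's current position and returns its start offset. */
  static int write(ByteBuffer page, byte[] key, byte[] value) {
    int start = page.position();
    page.putInt(key.length + value.length + 4); // bytes 0-4: len(k) + len(v) + 4
    page.putInt(key.length);                    // bytes 4-8: len(k)
    page.put(key);                              // bytes 8 .. 8 + len(k)
    page.put(value);                            // bytes 8 + len(k) .. 8 + len(k) + len(v)
    return start;
  }

  /** Offset of the record that follows the one starting at {@code offset}. */
  static int nextRecordOffset(ByteBuffer page, int offset) {
    int totalLength = page.getInt(offset);
    return offset + 4 + totalLength; // skip the length word plus everything it counts
  }

  static byte[] readKey(ByteBuffer page, int offset) {
    int keyLength = page.getInt(offset + 4);
    byte[] key = new byte[keyLength];
    for (int i = 0; i < keyLength; i++) {
      key[i] = page.get(offset + 8 + i);
    }
    return key;
  }

  static byte[] readValue(ByteBuffer page, int offset) {
    int totalLength = page.getInt(offset);
    int keyLength = page.getInt(offset + 4);
    int valueLength = totalLength - keyLength - 4; // the length word counts the key-length word
    byte[] value = new byte[valueLength];
    for (int i = 0; i < valueLength; i++) {
      value[i] = page.get(offset + 8 + keyLength + i);
    }
    return value;
  }
}
```

Take a record with a 2-byte key and a 3-byte value written at offset 0. Its length word holds 9, and the next record starts at offset 0 + 4 + 9 = 13.
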
http://git-wip-us.apache.org/repos/asf/spark/blob/2e981b7b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index c05f2c3..b984301 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -17,9 +17,12 @@
 
 package org.apache.spark.util.collection.unsafe.sort;
 
+import java.io.File;
 import java.io.IOException;
 import java.util.LinkedList;
 
+import javax.annotation.Nullable;
+
 import scala.runtime.AbstractFunction0;
 import scala.runtime.BoxedUnit;
 
@@ -27,7 +30,6 @@ import com.google.common.annotations.VisibleForTesting;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import org.apache.spark.SparkConf;
 import org.apache.spark.TaskContext;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.shuffle.ShuffleMemoryManager;
@@ -48,7 +50,7 @@ public final class UnsafeExternalSorter {
   private final PrefixComparator prefixComparator;
   private final RecordComparator recordComparator;
   private final int initialSize;
-  private final TaskMemoryManager memoryManager;
+  private final TaskMemoryManager taskMemoryManager;
   private final ShuffleMemoryManager shuffleMemoryManager;
   private final BlockManager blockManager;
   private final TaskContext taskContext;
@@ -63,26 +65,57 @@ public final class UnsafeExternalSorter {
    * this might not be necessary if we maintained a pool of re-usable pages in 
the TaskMemoryManager
    * itself).
    */
-  private final LinkedList<MemoryBlock> allocatedPages = new LinkedList<MemoryBlock>();
+  private final LinkedList<MemoryBlock> allocatedPages = new LinkedList<>();
+
+  private final LinkedList<UnsafeSorterSpillWriter> spillWriters = new LinkedList<>();
 
   // These variables are reset after spilling:
-  private UnsafeInMemorySorter sorter;
+  private UnsafeInMemorySorter inMemSorter;
+  // Whether the in-mem sorter is created internally, or passed in from outside.
+  // If it is passed in from outside, we shouldn't release the in-mem sorter's memory.
+  private boolean isInMemSorterExternal = false;
   private MemoryBlock currentPage = null;
   private long currentPagePosition = -1;
   private long freeSpaceInCurrentPage = 0;
 
-  private final LinkedList<UnsafeSorterSpillWriter> spillWriters = new LinkedList<>();
+  public static UnsafeExternalSorter createWithExistingInMemorySorter(
+      TaskMemoryManager taskMemoryManager,
+      ShuffleMemoryManager shuffleMemoryManager,
+      BlockManager blockManager,
+      TaskContext taskContext,
+      RecordComparator recordComparator,
+      PrefixComparator prefixComparator,
+      int initialSize,
+      long pageSizeBytes,
+      UnsafeInMemorySorter inMemorySorter) throws IOException {
+    return new UnsafeExternalSorter(taskMemoryManager, shuffleMemoryManager, blockManager,
+      taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, inMemorySorter);
+  }
 
-  public UnsafeExternalSorter(
-      TaskMemoryManager memoryManager,
+  public static UnsafeExternalSorter create(
+      TaskMemoryManager taskMemoryManager,
       ShuffleMemoryManager shuffleMemoryManager,
       BlockManager blockManager,
       TaskContext taskContext,
       RecordComparator recordComparator,
       PrefixComparator prefixComparator,
       int initialSize,
-      SparkConf conf) throws IOException {
-    this.memoryManager = memoryManager;
+      long pageSizeBytes) throws IOException {
+    return new UnsafeExternalSorter(taskMemoryManager, shuffleMemoryManager, blockManager,
+      taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, null);
+  }
+
+  private UnsafeExternalSorter(
+      TaskMemoryManager taskMemoryManager,
+      ShuffleMemoryManager shuffleMemoryManager,
+      BlockManager blockManager,
+      TaskContext taskContext,
+      RecordComparator recordComparator,
+      PrefixComparator prefixComparator,
+      int initialSize,
+      long pageSizeBytes,
+      @Nullable UnsafeInMemorySorter existingInMemorySorter) throws IOException {
+    this.taskMemoryManager = taskMemoryManager;
     this.shuffleMemoryManager = shuffleMemoryManager;
     this.blockManager = blockManager;
     this.taskContext = taskContext;
@@ -90,9 +123,18 @@ public final class UnsafeExternalSorter {
     this.prefixComparator = prefixComparator;
     this.initialSize = initialSize;
     // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units
-    this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
-    this.pageSizeBytes = conf.getSizeAsBytes("spark.buffer.pageSize", "64m");
-    initializeForWriting();
+    // this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
+    this.fileBufferSizeBytes = 32 * 1024;
+    // this.pageSizeBytes = conf.getSizeAsBytes("spark.buffer.pageSize", "64m");
+    this.pageSizeBytes = pageSizeBytes;
+    this.writeMetrics = new ShuffleWriteMetrics();
+
+    if (existingInMemorySorter == null) {
+      initializeForWriting();
+    } else {
+      this.isInMemSorterExternal = true;
+      this.inMemSorter = existingInMemorySorter;
+    }
 
     // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at
     // the end of the task. This is necessary to avoid memory leaks in when the downstream operator
@@ -100,6 +142,7 @@ public final class UnsafeExternalSorter {
     taskContext.addOnCompleteCallback(new AbstractFunction0<BoxedUnit>() {
       @Override
       public BoxedUnit apply() {
+        deleteSpillFiles();
         freeMemory();
         return null;
       }
@@ -114,22 +157,31 @@ public final class UnsafeExternalSorter {
    */
   private void initializeForWriting() throws IOException {
     this.writeMetrics = new ShuffleWriteMetrics();
-    // TODO: move this sizing calculation logic into a static method of sorter:
-    final long memoryRequested = initialSize * 8L * 2;
-    final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested);
-    if (memoryAcquired != memoryRequested) {
+    final long pointerArrayMemory =
+      UnsafeInMemorySorter.getMemoryRequirementsForPointerArray(initialSize);
+    final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pointerArrayMemory);
+    if (memoryAcquired != pointerArrayMemory) {
       shuffleMemoryManager.release(memoryAcquired);
-      throw new IOException("Could not acquire " + memoryRequested + " bytes of memory");
+      throw new IOException("Could not acquire " + pointerArrayMemory + " bytes of memory");
     }
 
-    this.sorter =
-      new UnsafeInMemorySorter(memoryManager, recordComparator, prefixComparator, initialSize);
+    this.inMemSorter =
+      new UnsafeInMemorySorter(taskMemoryManager, recordComparator, prefixComparator, initialSize);
+    this.isInMemSorterExternal = false;
   }
 
   /**
-   * Sort and spill the current records in response to memory pressure.
+   * Marks the current page as no-more-space-available, and as a result, either allocate a
+   * new page or spill when we see the next record.
    */
   @VisibleForTesting
+  public void closeCurrentPage() {
+    freeSpaceInCurrentPage = 0;
+  }
+
+  /**
+   * Sort and spill the current records in response to memory pressure.
+   */
   public void spill() throws IOException {
     logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
       Thread.currentThread().getId(),
@@ -139,9 +191,9 @@ public final class UnsafeExternalSorter {
 
     final UnsafeSorterSpillWriter spillWriter =
       new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics,
-        sorter.numRecords());
+        inMemSorter.numRecords());
     spillWriters.add(spillWriter);
-    final UnsafeSorterIterator sortedRecords = sorter.getSortedIterator();
+    final UnsafeSorterIterator sortedRecords = inMemSorter.getSortedIterator();
     while (sortedRecords.hasNext()) {
       sortedRecords.loadNext();
       final Object baseObject = sortedRecords.getBaseObject();
@@ -150,20 +202,24 @@ public final class UnsafeExternalSorter {
       spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix());
     }
     spillWriter.close();
-    final long sorterMemoryUsage = sorter.getMemoryUsage();
-    sorter = null;
-    shuffleMemoryManager.release(sorterMemoryUsage);
     final long spillSize = freeMemory();
+    // Note that this is more-or-less going to be a multiple of the page size, so wasted space in
+    // pages will currently be counted as memory spilled even though that space isn't actually
+    // written to disk. This also counts the space needed to store the sorter's pointer array.
     taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
     initializeForWriting();
   }
 
+  /**
+   * Return the total memory usage of this sorter, including the data pages and the sorter's pointer
+   * array.
+   */
   private long getMemoryUsage() {
     long totalPageSize = 0;
     for (MemoryBlock page : allocatedPages) {
       totalPageSize += page.size();
     }
-    return sorter.getMemoryUsage() + totalPageSize;
+    return inMemSorter.getMemoryUsage() + totalPageSize;
   }
 
   @VisibleForTesting
@@ -171,13 +227,26 @@ public final class UnsafeExternalSorter {
     return allocatedPages.size();
   }
 
+  /**
+   * Free this sorter's in-memory data structures, including its data pages and pointer array.
+   *
+   * @return the number of bytes freed.
+   */
   public long freeMemory() {
     long memoryFreed = 0;
     for (MemoryBlock block : allocatedPages) {
-      memoryManager.freePage(block);
+      taskMemoryManager.freePage(block);
       shuffleMemoryManager.release(block.size());
       memoryFreed += block.size();
     }
+    if (inMemSorter != null) {
+      if (!isInMemSorterExternal) {
+        long sorterMemoryUsage = inMemSorter.getMemoryUsage();
+        memoryFreed += sorterMemoryUsage;
+        shuffleMemoryManager.release(sorterMemoryUsage);
+      }
+      inMemSorter = null;
+    }
     allocatedPages.clear();
     currentPage = null;
     currentPagePosition = -1;
@@ -186,6 +255,20 @@ public final class UnsafeExternalSorter {
   }
 
   /**
+   * Deletes any spill files created by this sorter.
+   */
+  public void deleteSpillFiles() {
+    for (UnsafeSorterSpillWriter spill : spillWriters) {
+      File file = spill.getFile();
+      if (file != null && file.exists()) {
+        if (!file.delete()) {
+          logger.error("Was unable to delete spill file {}", file.getAbsolutePath());
+        };
+      }
+    }
+  }
+
+  /**
    * Checks whether there is enough space to insert a new record into the sorter.
    *
    * @param requiredSpace the required space in the data page, in bytes, including space for storing
@@ -195,7 +278,7 @@ public final class UnsafeExternalSorter {
    */
   private boolean haveSpaceForRecord(int requiredSpace) {
     assert (requiredSpace > 0);
-    return (sorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage));
+    return (inMemSorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage));
   }
 
   /**
@@ -210,16 +293,16 @@ public final class UnsafeExternalSorter {
     // TODO: merge these steps to first calculate total memory requirements for this insert,
     // then try to acquire; no point in acquiring sort buffer only to spill due to no space in the
     // data page.
-    if (!sorter.hasSpaceForAnotherRecord()) {
+    if (!inMemSorter.hasSpaceForAnotherRecord()) {
       logger.debug("Attempting to expand sort pointer array");
-      final long oldPointerArrayMemoryUsage = sorter.getMemoryUsage();
+      final long oldPointerArrayMemoryUsage = inMemSorter.getMemoryUsage();
       final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2;
       final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray);
       if (memoryAcquired < memoryToGrowPointerArray) {
         shuffleMemoryManager.release(memoryAcquired);
         spill();
       } else {
-        sorter.expandPointerArray();
+        inMemSorter.expandPointerArray();
         shuffleMemoryManager.release(oldPointerArrayMemoryUsage);
       }
     }
@@ -236,7 +319,9 @@ public final class UnsafeExternalSorter {
       } else {
         final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
         if (memoryAcquired < pageSizeBytes) {
-          shuffleMemoryManager.release(memoryAcquired);
+          if (memoryAcquired > 0) {
+            shuffleMemoryManager.release(memoryAcquired);
+          }
           spill();
           final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
           if (memoryAcquiredAfterSpilling != pageSizeBytes) {
@@ -244,7 +329,7 @@ public final class UnsafeExternalSorter {
             throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
           }
         }
-        currentPage = memoryManager.allocatePage(pageSizeBytes);
+        currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
         currentPagePosition = currentPage.getBaseOffset();
         freeSpaceInCurrentPage = pageSizeBytes;
         allocatedPages.add(currentPage);
@@ -267,7 +352,7 @@ public final class UnsafeExternalSorter {
     }
 
     final long recordAddress =
-      memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
+      taskMemoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
     final Object dataPageBaseObject = currentPage.getBaseObject();
     PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, lengthInBytes);
     currentPagePosition += 4;
@@ -279,26 +364,48 @@ public final class UnsafeExternalSorter {
       lengthInBytes);
     currentPagePosition += lengthInBytes;
     freeSpaceInCurrentPage -= totalSpaceRequired;
-    sorter.insertRecord(recordAddress, prefix);
+    inMemSorter.insertRecord(recordAddress, prefix);
   }
 
   /**
-   * Write a record to the sorter. The record is broken down into two different parts, and
+   * Write a key-value record to the sorter. The key and value will be put together in-memory,
+   * using the following format:
    *
+   * record length (4 bytes), key length (4 bytes), key data, value data
+   *
+   * record length = key length + value length + 4
    */
-  public void insertRecord(
-      Object recordBaseObject1,
-      long recordBaseOffset1,
-      int lengthInBytes1,
-      Object recordBaseObject2,
-      long recordBaseOffset2,
-      int lengthInBytes2,
-      long prefix) throws IOException {
+  public void insertKVRecord(
+      Object keyBaseObj, long keyOffset, int keyLen,
+      Object valueBaseObj, long valueOffset, int valueLen, long prefix) throws IOException {
+    final int totalSpaceRequired = keyLen + valueLen + 4 + 4;
+    if (!haveSpaceForRecord(totalSpaceRequired)) {
+      allocateSpaceForRecord(totalSpaceRequired);
+    }
+
+    final long recordAddress =
+      taskMemoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
+    final Object dataPageBaseObject = currentPage.getBaseObject();
+    PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, keyLen + valueLen + 4);
+    currentPagePosition += 4;
 
+    PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, keyLen);
+    currentPagePosition += 4;
+
+    PlatformDependent.copyMemory(
+      keyBaseObj, keyOffset, dataPageBaseObject, currentPagePosition, keyLen);
+    currentPagePosition += keyLen;
+
+    PlatformDependent.copyMemory(
+      valueBaseObj, valueOffset, dataPageBaseObject, currentPagePosition, valueLen);
+    currentPagePosition += valueLen;
+
+    freeSpaceInCurrentPage -= totalSpaceRequired;
+    inMemSorter.insertRecord(recordAddress, prefix);
   }
 
   public UnsafeSorterIterator getSortedIterator() throws IOException {
-    final UnsafeSorterIterator inMemoryIterator = sorter.getSortedIterator();
+    final UnsafeSorterIterator inMemoryIterator = inMemSorter.getSortedIterator();
     int numIteratorsToMerge = spillWriters.size() + (inMemoryIterator.hasNext() ? 1 : 0);
     if (spillWriters.isEmpty()) {
       return inMemoryIterator;

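The `isInMemSorterExternal` flag and the two factory methods encode an ownership rule for memory accounting. A sorter that builds its own in-memory sorter reserved that pointer array's memory and must release it in `freeMemory()`. A sorter created around an existing in-memory sorter must not release it (per the field comment above), because doing so could release the same memory twice. The toy model below illustrates that rule. `ToyMemoryManager` and `ToyExternalSorter` are hypothetical stand-ins for `ShuffleMemoryManager` and `UnsafeExternalSorter`, not Spark classes. The 16 bytes per record matches `getMemoryRequirementsForPointerArray` in UnsafeInMemorySorter.

```java
/** Hypothetical per-task memory budget with tryToAcquire/release semantics. */
final class ToyMemoryManager {
  private final long limit;
  private long used = 0;

  ToyMemoryManager(long limit) {
    this.limit = limit;
  }

  /** Grants up to numBytes; may grant less (even zero) when the budget is nearly exhausted. */
  long tryToAcquire(long numBytes) {
    long granted = Math.min(numBytes, limit - used);
    used += granted;
    return granted;
  }

  void release(long numBytes) {
    if (numBytes > used) {
      throw new IllegalStateException("Releasing " + numBytes + " bytes but only " + used + " held");
    }
    used -= numBytes;
  }

  long used() {
    return used;
  }
}

/** Hypothetical sorter showing the "owned vs. external in-memory sorter" rule. */
final class ToyExternalSorter {
  private final ToyMemoryManager memoryManager;
  private final long pointerArrayBytes;
  private final boolean inMemSorterExternal;
  private boolean freed = false;

  /** Reserves and owns a fresh pointer array for initialSize records (16 bytes each). */
  static ToyExternalSorter create(ToyMemoryManager mm, int initialSize) {
    long needed = initialSize * 16L;
    long acquired = mm.tryToAcquire(needed);
    if (acquired != needed) {
      mm.release(acquired); // give back a partial grant before failing
      throw new IllegalStateException("Could not acquire " + needed + " bytes of memory");
    }
    return new ToyExternalSorter(mm, needed, false);
  }

  /** Wraps an in-memory sorter whose memory was reserved by someone else. */
  static ToyExternalSorter createWithExistingInMemorySorter(ToyMemoryManager mm, long existingBytes) {
    return new ToyExternalSorter(mm, existingBytes, true);
  }

  private ToyExternalSorter(ToyMemoryManager mm, long pointerArrayBytes, boolean external) {
    this.memoryManager = mm;
    this.pointerArrayBytes = pointerArrayBytes;
    this.inMemSorterExternal = external;
  }

  /** Releases only memory this sorter reserved itself; returns bytes released. Idempotent. */
  long freeMemory() {
    if (freed) {
      return 0;
    }
    freed = true;
    if (inMemSorterExternal) {
      return 0; // the owner of the external sorter is responsible for releasing it
    }
    memoryManager.release(pointerArrayBytes);
    return pointerArrayBytes;
  }
}
```

`create` returns a partial grant before throwing, and `initializeForWriting()` does the same in the diff. This keeps a failed acquisition from leaking part of the task's budget.
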
http://git-wip-us.apache.org/repos/asf/spark/blob/2e981b7b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index fc34ad9..3131465 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -100,6 +100,10 @@ public final class UnsafeInMemorySorter {
     return pointerArray.length * 8L;
   }
 
+  static long getMemoryRequirementsForPointerArray(long numEntries) {
+    return numEntries * 2L * 8L;
+  }
+
   public boolean hasSpaceForAnotherRecord() {
     return pointerArrayInsertPosition + 2 < pointerArray.length;
   }

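Here is a worked example of the pointer-array arithmetic. Each record takes two `long` slots, a record address and a key prefix, matching `insertRecord(recordAddress, prefix)` in UnsafeExternalSorter. So `getMemoryRequirementsForPointerArray(1024)` is 1024 * 2 * 8 = 16,384 bytes, the same value the old inline `initialSize * 8L * 2` produced. The hypothetical `PointerArraySizing` class below restates the helper and the `hasSpaceForAnotherRecord` check outside Spark. It assumes the array holds `2 * initialSize` longs, which is what makes the helper agree with `getMemoryUsage()` (`pointerArray.length * 8`).

```java
/** Standalone restatement of the pointer-array arithmetic (hypothetical class). */
final class PointerArraySizing {

  /** Same formula as getMemoryRequirementsForPointerArray: 2 longs of 8 bytes per entry. */
  static long memoryRequirementsForPointerArray(long numEntries) {
    return numEntries * 2L * 8L;
  }

  /** Same check as hasSpaceForAnotherRecord, given the next free slot and the array length. */
  static boolean hasSpaceForAnotherRecord(int insertPosition, int arrayLength) {
    return insertPosition + 2 < arrayLength;
  }

  /** Counts how many records are accepted before the check first fails. */
  static int recordsBeforeExpansion(int arrayLength) {
    int insertPosition = 0;
    int records = 0;
    while (hasSpaceForAnotherRecord(insertPosition, arrayLength)) {
      insertPosition += 2; // one address slot + one prefix slot
      records++;
    }
    return records;
  }
}
```

The check uses a strict `<` on `insertPosition + 2`, so a 2,048-slot array reports no more space after 1,023 records. At that point UnsafeExternalSorter tries to double the pointer array, and spills if it cannot acquire the memory.
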
http://git-wip-us.apache.org/repos/asf/spark/blob/2e981b7b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
index 29e9e0f..ca1cced 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
@@ -31,6 +31,7 @@ import org.apache.spark.unsafe.PlatformDependent;
  */
 final class UnsafeSorterSpillReader extends UnsafeSorterIterator {
 
+  private final File file;
   private InputStream in;
   private DataInputStream din;
 
@@ -48,6 +49,7 @@ final class UnsafeSorterSpillReader extends UnsafeSorterIterator {
       File file,
       BlockId blockId) throws IOException {
     assert (file.length() > 0);
+    this.file = file;
     final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file));
     this.in = blockManager.wrapForCompression(blockId, bs);
     this.din = new DataInputStream(this.in);
@@ -71,6 +73,7 @@ final class UnsafeSorterSpillReader extends UnsafeSorterIterator {
     numRecordsRemaining--;
     if (numRecordsRemaining == 0) {
       in.close();
+      file.delete();
       in = null;
       din = null;
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/2e981b7b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
index 71eed29..44cf6c7 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
@@ -140,6 +140,10 @@ final class UnsafeSorterSpillWriter {
     writeBuffer = null;
   }
 
+  public File getFile() {
+    return file;
+  }
+
   public UnsafeSorterSpillReader getReader(BlockManager blockManager) throws IOException {
     return new UnsafeSorterSpillReader(blockManager, file, blockId);
   }

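Together, the `file` field in the reader, `getFile()` on the writer, and `deleteSpillFiles()` in UnsafeExternalSorter give spill files two cleanup paths. First, the reader deletes its file once the last record has been consumed. Second, the task-completion callback deletes whatever is left, such as spills that were never fully read. The sketch below shows that pattern with `java.io` temp files. `SpillFileRegistry` and its methods are hypothetical, not Spark classes.

```java
import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;

/** Hypothetical tracker for spill files, mirroring deleteSpillFiles() in the diff. */
final class SpillFileRegistry {
  private final List<File> spillFiles = new ArrayList<>();
  private final File dir; // null means the default temporary-file directory

  SpillFileRegistry(File dir) {
    this.dir = dir;
  }

  /** Creates and remembers a new spill file (stands in for UnsafeSorterSpillWriter). */
  File newSpillFile() {
    try {
      File f = File.createTempFile("spillFile", ".spill", dir);
      spillFiles.add(f);
      return f;
    } catch (IOException e) {
      throw new UncheckedIOException(e);
    }
  }

  /** Reader-side cleanup, called after the last record of a spill has been read. */
  static void onLastRecordRead(File spillFile) {
    spillFile.delete(); // best effort; deleteSpillFiles() below is the backstop
  }

  /** Task-completion cleanup: deletes any spill file that still exists; returns failures. */
  List<File> deleteSpillFiles() {
    List<File> failures = new ArrayList<>();
    for (File file : spillFiles) {
      if (file != null && file.exists() && !file.delete()) {
        failures.add(file); // the code in the diff logs an error instead
      }
    }
    return failures;
  }
}
```

Both paths check `exists()` or tolerate a missing file, so a file already deleted by the reader is skipped by the later cleanup.
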
http://git-wip-us.apache.org/repos/asf/spark/blob/2e981b7b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index 70f8ca4..dbb7c66 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -67,12 +67,11 @@ public abstract class AbstractBytesToBytesMapSuite {
 
   @After
   public void tearDown() {
-    if (taskMemoryManager != null) {
+    Assert.assertEquals(0L, taskMemoryManager.cleanUpAllAllocatedMemory());
+    if (shuffleMemoryManager != null) {
       long leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask();
-      Assert.assertEquals(0, taskMemoryManager.cleanUpAllAllocatedMemory());
-      Assert.assertEquals(0, leakedShuffleMemory);
       shuffleMemoryManager = null;
-      taskMemoryManager = null;
+      Assert.assertEquals(0L, leakedShuffleMemory);
     }
   }
 

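The revised `tearDown()` methods turn leaks into test failures. After each test they ask the memory managers how much memory is still held, and assert that the amount is zero. The sorter suite also checks that every spill file it handed out was deleted. The sketch below shows the same idea without a test framework. `TrackingAllocator` is hypothetical and only imitates the role of `cleanUpAllAllocatedMemory()`. Judging by the assertions above, that method returns the number of bytes that were still allocated.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical allocator that remembers outstanding blocks so a test can detect leaks. */
final class TrackingAllocator {
  private final Map<Integer, Long> outstanding = new HashMap<>();
  private int nextId = 0;

  /** Allocates a block of the given size and returns its handle. */
  int allocate(long size) {
    int id = nextId++;
    outstanding.put(id, size);
    return id;
  }

  /** Frees a block; freeing an unknown or already-freed handle is an error. */
  void free(int id) {
    if (outstanding.remove(id) == null) {
      throw new IllegalStateException("double free or unknown block " + id);
    }
  }

  /** Frees everything still allocated and returns how many bytes that was (0 means no leak). */
  long cleanUpAllAllocatedMemory() {
    long leaked = 0;
    for (long size : outstanding.values()) {
      leaked += size;
    }
    outstanding.clear();
    return leaked;
  }
}
```

A teardown would call `cleanUpAllAllocatedMemory()` first and then assert the result is zero. This mirrors the new AbstractBytesToBytesMapSuite, so leftover memory is reclaimed even when the assertion fails.
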
http://git-wip-us.apache.org/repos/asf/spark/blob/2e981b7b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
----------------------------------------------------------------------
diff --git 
a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
 
b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index 0e391b7..52fa8bc 100644
--- 
a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ 
b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -20,12 +20,14 @@ package org.apache.spark.util.collection.unsafe.sort;
 import java.io.File;
 import java.io.InputStream;
 import java.io.OutputStream;
+import java.util.LinkedList;
 import java.util.UUID;
 
 import scala.Tuple2;
 import scala.Tuple2$;
 import scala.runtime.AbstractFunction1;
 
+import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 import org.mockito.Mock;
@@ -33,7 +35,6 @@ import org.mockito.MockitoAnnotations;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
 import static org.junit.Assert.*;
-import static org.mockito.AdditionalAnswers.returnsFirstArg;
 import static org.mockito.AdditionalAnswers.returnsSecondArg;
 import static org.mockito.Answers.RETURNS_SMART_NULLS;
 import static org.mockito.Mockito.*;
@@ -53,7 +54,8 @@ import org.apache.spark.util.Utils;
 
 public class UnsafeExternalSorterSuite {
 
-  final TaskMemoryManager memoryManager =
+  final LinkedList<File> spillFilesCreated = new LinkedList<File>();
+  final TaskMemoryManager taskMemoryManager =
     new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
   // Use integer comparison for comparing prefixes (which are partition ids, in this case)
   final PrefixComparator prefixComparator = new PrefixComparator() {
@@ -75,13 +77,15 @@ public class UnsafeExternalSorterSuite {
     }
   };
 
-  @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager;
+  ShuffleMemoryManager shuffleMemoryManager;
   @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
   @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
   @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
 
   File tempDir;
 
+  private final long pageSizeBytes = new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "64m");
+
   private static final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
     @Override
     public OutputStream apply(OutputStream stream) {
@@ -93,15 +97,17 @@ public class UnsafeExternalSorterSuite {
   public void setUp() {
     MockitoAnnotations.initMocks(this);
     tempDir = new File(Utils.createTempDir$default$1());
+    shuffleMemoryManager = new ShuffleMemoryManager(Long.MAX_VALUE);
+    spillFilesCreated.clear();
     taskContext = mock(TaskContext.class);
     when(taskContext.taskMetrics()).thenReturn(new TaskMetrics());
-    when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg());
     when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
     when(diskBlockManager.createTempLocalBlock()).thenAnswer(new 
Answer<Tuple2<TempLocalBlockId, File>>() {
       @Override
       public Tuple2<TempLocalBlockId, File> answer(InvocationOnMock invocationOnMock) throws Throwable {
         TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID());
         File file = File.createTempFile("spillFile", ".spill", tempDir);
+        spillFilesCreated.add(file);
         return Tuple2$.MODULE$.apply(blockId, file);
       }
     });
@@ -130,6 +136,24 @@ public class UnsafeExternalSorterSuite {
       .then(returnsSecondArg());
   }
 
+  @After
+  public void tearDown() {
+    long leakedUnsafeMemory = taskMemoryManager.cleanUpAllAllocatedMemory();
+    if (shuffleMemoryManager != null) {
+      long leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask();
+      shuffleMemoryManager = null;
+      assertEquals(0L, leakedShuffleMemory);
+    }
+    assertEquals(0, leakedUnsafeMemory);
+  }
+
+  private void assertSpillFilesWereCleanedUp() {
+    for (File spillFile : spillFilesCreated) {
+      assertFalse("Spill file " + spillFile.getPath() + " was not cleaned up",
+        spillFile.exists());
+    }
+  }
+
   private static void insertNumber(UnsafeExternalSorter sorter, int value) throws Exception {
     final int[] arr = new int[] { value };
     sorter.insertRecord(arr, PlatformDependent.INT_ARRAY_OFFSET, 4, value);
@@ -138,15 +162,15 @@ public class UnsafeExternalSorterSuite {
   @Test
   public void testSortingOnlyByPrefix() throws Exception {
 
-    final UnsafeExternalSorter sorter = new UnsafeExternalSorter(
-      memoryManager,
+    final UnsafeExternalSorter sorter = UnsafeExternalSorter.create(
+      taskMemoryManager,
       shuffleMemoryManager,
       blockManager,
       taskContext,
       recordComparator,
       prefixComparator,
-      1024,
-      new SparkConf());
+      /* initialSize */ 1024,
+      pageSizeBytes);
 
     insertNumber(sorter, 5);
     insertNumber(sorter, 1);
@@ -165,22 +189,22 @@ public class UnsafeExternalSorterSuite {
       // TODO: read rest of value.
     }
 
-    // TODO: test for cleanup:
-    // assert(tempDir.isEmpty)
+    sorter.freeMemory();
+    assertSpillFilesWereCleanedUp();
   }
 
   @Test
   public void testSortingEmptyArrays() throws Exception {
 
-    final UnsafeExternalSorter sorter = new UnsafeExternalSorter(
-      memoryManager,
+    final UnsafeExternalSorter sorter = UnsafeExternalSorter.create(
+      taskMemoryManager,
       shuffleMemoryManager,
       blockManager,
       taskContext,
       recordComparator,
       prefixComparator,
-      1024,
-      new SparkConf());
+      /* initialSize */ 1024,
+      pageSizeBytes);
 
     sorter.insertRecord(null, 0, 0, 0);
     sorter.insertRecord(null, 0, 0, 0);
@@ -197,25 +221,30 @@ public class UnsafeExternalSorterSuite {
       assertEquals(0, iter.getKeyPrefix());
       assertEquals(0, iter.getRecordLength());
     }
+
+    sorter.freeMemory();
+    assertSpillFilesWereCleanedUp();
   }
 
   @Test
   public void testFillingPage() throws Exception {
-    final UnsafeExternalSorter sorter = new UnsafeExternalSorter(
-      memoryManager,
+
+    final UnsafeExternalSorter sorter = UnsafeExternalSorter.create(
+      taskMemoryManager,
       shuffleMemoryManager,
       blockManager,
       taskContext,
       recordComparator,
       prefixComparator,
-      1024,
-      new SparkConf());
+      /* initialSize */ 1024,
+      pageSizeBytes);
 
     byte[] record = new byte[16];
     while (sorter.getNumberOfAllocatedPages() < 2) {
       sorter.insertRecord(record, PlatformDependent.BYTE_ARRAY_OFFSET, record.length, 0);
     }
     sorter.freeMemory();
+    assertSpillFilesWereCleanedUp();
   }
 
 }
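The new tearDown() relies on TaskMemoryManager.cleanUpAllAllocatedMemory() returning the number of bytes that were still allocated when the test finished, so any non-zero value is a leak. A minimal sketch of that accounting idea, assuming a hypothetical LeakCheckingAllocator that hands out plain byte[] blocks (this is not Spark's implementation):

```java
import java.util.IdentityHashMap;
import java.util.Map;

/**
 * Hypothetical stand-in for TaskMemoryManager's leak accounting: every
 * allocation is tracked until freed, and cleanUpAllAllocatedMemory()
 * drops whatever is left and reports how many bytes leaked.
 */
public final class LeakCheckingAllocator {
  // Identity map: two distinct blocks with equal contents are still distinct allocations.
  private final Map<byte[], Long> live = new IdentityHashMap<>();

  public byte[] allocate(int size) {
    byte[] block = new byte[size];
    live.put(block, (long) size);
    return block;
  }

  public void free(byte[] block) {
    if (live.remove(block) == null) {
      throw new IllegalStateException("double free or foreign block");
    }
  }

  /** Releases all remaining blocks and returns the number of leaked bytes. */
  public long cleanUpAllAllocatedMemory() {
    long leaked = 0;
    for (long size : live.values()) {
      leaked += size;
    }
    live.clear();
    return leaked;
  }

  public static void main(String[] args) {
    LeakCheckingAllocator alloc = new LeakCheckingAllocator();
    byte[] page = alloc.allocate(1024);
    alloc.free(page);
    System.out.println(alloc.cleanUpAllAllocatedMemory()); // 0
    alloc.allocate(64); // never freed
    System.out.println(alloc.cleanUpAllAllocatedMemory()); // 64
  }
}
```

In a test, asserting that the returned value is 0 in an @After method is what turns a silent leak into a failure.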

http://git-wip-us.apache.org/repos/asf/spark/blob/2e981b7b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 1b475b2..b4fc0b7 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -507,7 +507,8 @@ public final class UnsafeRow extends MutableRow {
   public String toString() {
     StringBuilder build = new StringBuilder("[");
     for (int i = 0; i < sizeInBytes; i += 8) {
-      build.append(PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + i));
+      build.append(java.lang.Long.toHexString(
+        PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + i)));
       build.append(',');
     }
     build.append(']');
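The UnsafeRow.toString() change prints each 8-byte word in hex instead of as a signed decimal, which makes bit patterns easier to read when debugging. The sketch below mirrors that loop, assuming a ByteBuffer in place of the baseObject/baseOffset pair read through PlatformDependent.UNSAFE; the class name HexWordDump is hypothetical. Like the patched code, it emits a trailing comma before the closing bracket.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

/** Hypothetical mirror of the patched UnsafeRow.toString() word loop. */
public final class HexWordDump {
  static String dump(ByteBuffer buf, int sizeInBytes) {
    StringBuilder build = new StringBuilder("[");
    // Walk the row one 64-bit word at a time, printing each word in hex.
    for (int i = 0; i < sizeInBytes; i += 8) {
      build.append(Long.toHexString(buf.getLong(i)));
      build.append(',');
    }
    build.append(']');
    return build.toString();
  }

  public static void main(String[] args) {
    ByteBuffer buf = ByteBuffer.allocate(16).order(ByteOrder.nativeOrder());
    buf.putLong(0, 0L);   // first word
    buf.putLong(8, 255L); // second word
    System.out.println(dump(buf, 16)); // [0,ff,]
  }
}
```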

http://git-wip-us.apache.org/repos/asf/spark/blob/2e981b7b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 68c49fe..5e4c623 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -59,20 +59,21 @@ final class UnsafeExternalRowSorter {
       StructType schema,
       Ordering<InternalRow> ordering,
       PrefixComparator prefixComparator,
-      PrefixComputer prefixComputer) throws IOException {
+      PrefixComputer prefixComputer,
+      long pageSizeBytes) throws IOException {
     this.schema = schema;
     this.prefixComputer = prefixComputer;
     final SparkEnv sparkEnv = SparkEnv.get();
     final TaskContext taskContext = TaskContext.get();
-    sorter = new UnsafeExternalSorter(
+    sorter = UnsafeExternalSorter.create(
       taskContext.taskMemoryManager(),
       sparkEnv.shuffleMemoryManager(),
       sparkEnv.blockManager(),
       taskContext,
       new RowComparator(ordering, schema.length()),
       prefixComparator,
-      4096,
-      sparkEnv.conf()
+      /* initialSize */ 4096,
+      pageSizeBytes
     );
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/2e981b7b/sql/core/pom.xml
----------------------------------------------------------------------
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index be09666..3490077 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -106,6 +106,11 @@
       <artifactId>parquet-avro</artifactId>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.mockito</groupId>
+      <artifactId>mockito-core</artifactId>
+      <scope>test</scope>
+    </dependency>
   </dependencies>
   <build>
     <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>

http://git-wip-us.apache.org/repos/asf/spark/blob/2e981b7b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index a0a8dd5..9e2c933 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -19,24 +19,18 @@ package org.apache.spark.sql.execution;
 
 import java.io.IOException;
 
+import org.apache.spark.SparkEnv;
 import org.apache.spark.shuffle.ShuffleMemoryManager;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
-import org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering;
-import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 import org.apache.spark.unsafe.KVIterator;
 import org.apache.spark.unsafe.PlatformDependent;
 import org.apache.spark.unsafe.map.BytesToBytesMap;
-import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.unsafe.memory.MemoryLocation;
 import org.apache.spark.unsafe.memory.TaskMemoryManager;
-import org.apache.spark.util.collection.unsafe.sort.PrefixComparator;
-import org.apache.spark.util.collection.unsafe.sort.RecordComparator;
-import org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter;
-import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator;
 
 /**
  * Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width.
@@ -215,7 +209,7 @@ public final class UnsafeFixedWidthAggregationMap {
   }
 
   /**
-   * Free the unsafe memory associated with this map.
+   * Free the memory associated with this map. This is idempotent and can be called multiple times.
    */
   public void free() {
     map.free();
@@ -233,92 +227,17 @@ public final class UnsafeFixedWidthAggregationMap {
   }
 
   /**
-   * Sorts the key, value data in this map in place, and return them as an iterator.
+   * Sorts the map's records in place, spills them to disk, and returns an {@link UnsafeKVExternalSorter}
+   * that can be used to insert more records to do external sorting.
    *
    * The only memory that is allocated is the address/prefix array, 16 bytes per record.
+   *
+   * Note that this destroys the map, so the map cannot be used after this call.
    */
-  public KVIterator<UnsafeRow, UnsafeRow> sortedIterator() {
-    int numElements = map.numElements();
-    final int numKeyFields = groupingKeySchema.size();
-    TaskMemoryManager memoryManager = map.getTaskMemoryManager();
-
-    UnsafeExternalRowSorter.PrefixComputer prefixComp =
-      SortPrefixUtils.createPrefixGenerator(groupingKeySchema);
-    PrefixComparator prefixComparator = SortPrefixUtils.getPrefixComparator(groupingKeySchema);
-
-    final BaseOrdering ordering = GenerateOrdering.create(groupingKeySchema);
-    RecordComparator recordComparator = new RecordComparator() {
-      private final UnsafeRow row1 = new UnsafeRow();
-      private final UnsafeRow row2 = new UnsafeRow();
-
-      @Override
-      public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) {
-        row1.pointTo(baseObj1, baseOff1 + 4, numKeyFields, -1);
-        row2.pointTo(baseObj2, baseOff2 + 4, numKeyFields, -1);
-        return ordering.compare(row1, row2);
-      }
-    };
-
-    // Insert the records into the in-memory sorter.
-    final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(
-      memoryManager, recordComparator, prefixComparator, numElements);
-
-    BytesToBytesMap.BytesToBytesMapIterator iter = map.iterator();
-    UnsafeRow row = new UnsafeRow();
-    while (iter.hasNext()) {
-      final BytesToBytesMap.Location loc = iter.next();
-      final Object baseObject = loc.getKeyAddress().getBaseObject();
-      final long baseOffset = loc.getKeyAddress().getBaseOffset();
-
-      // Get encoded memory address
-      MemoryBlock page = loc.getMemoryPage();
-      long address = memoryManager.encodePageNumberAndOffset(page, baseOffset - 8);
-
-      // Compute prefix
-      row.pointTo(baseObject, baseOffset, numKeyFields, loc.getKeyLength());
-      final long prefix = prefixComp.computePrefix(row);
-
-      sorter.insertRecord(address, prefix);
-    }
-
-    // Return the sorted result as an iterator.
-    return new KVIterator<UnsafeRow, UnsafeRow>() {
-
-      private UnsafeSorterIterator sortedIterator = sorter.getSortedIterator();
-      private final UnsafeRow key = new UnsafeRow();
-      private final UnsafeRow value = new UnsafeRow();
-      private int numValueFields = aggregationBufferSchema.size();
-
-      @Override
-      public boolean next() throws IOException {
-        if (sortedIterator.hasNext()) {
-          sortedIterator.loadNext();
-          Object baseObj = sortedIterator.getBaseObject();
-          long recordOffset = sortedIterator.getBaseOffset();
-          int recordLen = sortedIterator.getRecordLength();
-          int keyLen = PlatformDependent.UNSAFE.getInt(baseObj, recordOffset);
-          key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen);
-          value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, recordLen - keyLen);
-          return true;
-        } else {
-          return false;
-        }
-      }
-
-      @Override
-      public UnsafeRow getKey() {
-        return key;
-      }
-
-      @Override
-      public UnsafeRow getValue() {
-        return value;
-      }
-
-      @Override
-      public void close() {
-        // Do nothing
-      }
-    };
+  public UnsafeKVExternalSorter destructAndCreateExternalSorter() throws IOException {
+    UnsafeKVExternalSorter sorter = new UnsafeKVExternalSorter(
+      groupingKeySchema, aggregationBufferSchema,
+      SparkEnv.get().blockManager(), map.getShuffleMemoryManager(), map.getPageSizeBytes(), map);
+    return sorter;
   }
 }
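destructAndCreateExternalSorter() sorts the map's records, spills them as a sorted run, frees the map, and returns a sorter that keeps accepting records. A toy model of that control flow, assuming a hypothetical ToyKVExternalSorter that keeps its "spilled" runs in memory instead of writing spill files, uses String keys and long values instead of UnsafeRows, and merges runs with a PriorityQueue (the real UnsafeKVExternalSorter sorts UnsafeRow bytes held in managed memory pages):

```java
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;

/** Hypothetical toy model of destruct-sort-spill followed by external sorting. */
public final class ToyKVExternalSorter {
  private final List<List<Map.Entry<String, Long>>> spilledRuns = new ArrayList<>();
  private final List<Map.Entry<String, Long>> buffer = new ArrayList<>();
  private final int maxBufferedRecords;

  /** Consumes {@code map}: its contents become the first sorted run and the map is emptied. */
  ToyKVExternalSorter(Map<String, Long> map, int maxBufferedRecords) {
    this.maxBufferedRecords = maxBufferedRecords;
    for (Map.Entry<String, Long> e : map.entrySet()) {
      buffer.add(new AbstractMap.SimpleImmutableEntry<>(e.getKey(), e.getValue()));
    }
    spill();
    map.clear(); // stands in for map.free() in the real code
  }

  void insertKV(String key, long value) {
    if (buffer.size() >= maxBufferedRecords) {
      spill(); // "out of memory": sort what we have and write it out as a run
    }
    buffer.add(new AbstractMap.SimpleImmutableEntry<>(key, value));
  }

  private void spill() {
    if (buffer.isEmpty()) {
      return; // never create empty runs
    }
    List<Map.Entry<String, Long>> run = new ArrayList<>(buffer);
    run.sort(Map.Entry.comparingByKey());
    spilledRuns.add(run);
    buffer.clear();
  }

  /** K-way merge of all sorted runs; equal keys come out adjacent. */
  List<Map.Entry<String, Long>> sorted() {
    spill();
    // Each heap element is {runIndex, positionInRun}.
    PriorityQueue<int[]> heap = new PriorityQueue<>(
        (a, b) -> spilledRuns.get(a[0]).get(a[1]).getKey()
            .compareTo(spilledRuns.get(b[0]).get(b[1]).getKey()));
    for (int r = 0; r < spilledRuns.size(); r++) {
      heap.add(new int[] {r, 0});
    }
    List<Map.Entry<String, Long>> out = new ArrayList<>();
    while (!heap.isEmpty()) {
      int[] top = heap.poll();
      List<Map.Entry<String, Long>> run = spilledRuns.get(top[0]);
      out.add(run.get(top[1]));
      if (top[1] + 1 < run.size()) {
        heap.add(new int[] {top[0], top[1] + 1});
      }
    }
    return out;
  }

  public static void main(String[] args) {
    Map<String, Long> agg = new HashMap<>();
    agg.put("b", 2L);
    agg.put("a", 1L);
    ToyKVExternalSorter sorter = new ToyKVExternalSorter(agg, 2);
    sorter.insertKV("c", 3L);
    sorter.insertKV("a", 10L);
    sorter.insertKV("d", 4L); // buffer is full, so {c, a} is spilled first
    System.out.println(sorter.sorted()); // keys in order a, a, b, c, d
  }
}
```

The two "a" entries come from different runs, and their relative order is not guaranteed; only the key order is. Grouping equal keys together is what a sort-based aggregation pass would then rely on.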

http://git-wip-us.apache.org/repos/asf/spark/blob/2e981b7b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
new file mode 100644
index 0000000..f6b0176
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -0,0 +1,236 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution;
+
+import java.io.IOException;
+
+import javax.annotation.Nullable;
+
+import com.google.common.annotations.VisibleForTesting;
+
+import org.apache.spark.TaskContext;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering;
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.unsafe.KVIterator;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.map.BytesToBytesMap;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.collection.unsafe.sort.*;
+
+/**
+ * A class for performing external sorting on key-value records. Both key and value are UnsafeRows.
+ *
+ * Note that this class allows optionally passing in a {@link BytesToBytesMap} directly in order
+ * to perform in-place sorting of records in the map.
+ */
+public final class UnsafeKVExternalSorter {
+
+  private final StructType keySchema;
+  private final StructType valueSchema;
+  private final UnsafeExternalRowSorter.PrefixComputer prefixComputer;
+  private final UnsafeExternalSorter sorter;
+
+  public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema,
+      BlockManager blockManager, ShuffleMemoryManager shuffleMemoryManager, long pageSizeBytes)
+    throws IOException {
+    this(keySchema, valueSchema, blockManager, shuffleMemoryManager, pageSizeBytes, null);
+  }
+
+  public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema,
+      BlockManager blockManager, ShuffleMemoryManager shuffleMemoryManager, long pageSizeBytes,
+      @Nullable BytesToBytesMap map) throws IOException {
+    this.keySchema = keySchema;
+    this.valueSchema = valueSchema;
+    final TaskContext taskContext = TaskContext.get();
+
+    prefixComputer = SortPrefixUtils.createPrefixGenerator(keySchema);
+    PrefixComparator prefixComparator = SortPrefixUtils.getPrefixComparator(keySchema);
+    BaseOrdering ordering = GenerateOrdering.create(keySchema);
+    KVComparator recordComparator = new KVComparator(ordering, keySchema.length());
+
+    TaskMemoryManager taskMemoryManager = taskContext.taskMemoryManager();
+
+    if (map == null) {
+      sorter = UnsafeExternalSorter.create(
+        taskMemoryManager,
+        shuffleMemoryManager,
+        blockManager,
+        taskContext,
+        recordComparator,
+        prefixComparator,
+        /* initialSize */ 4096,
+        pageSizeBytes);
+    } else {
+      // Insert the records into the in-memory sorter.
+      final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter(
+        taskMemoryManager, recordComparator, prefixComparator, map.numElements());
+
+      final int numKeyFields = keySchema.size();
+      BytesToBytesMap.BytesToBytesMapIterator iter = map.iterator();
+      UnsafeRow row = new UnsafeRow();
+      while (iter.hasNext()) {
+        final BytesToBytesMap.Location loc = iter.next();
+        final Object baseObject = loc.getKeyAddress().getBaseObject();
+        final long baseOffset = loc.getKeyAddress().getBaseOffset();
+
+        // Get encoded memory address
+        // baseObject + baseOffset point to the beginning of the key data in the map, but the
+        // KV-pair's length data is stored in the word immediately before that address.
+        MemoryBlock page = loc.getMemoryPage();
+        long address = taskMemoryManager.encodePageNumberAndOffset(page, baseOffset - 8);
+
+        // Compute prefix
+        row.pointTo(baseObject, baseOffset, numKeyFields, loc.getKeyLength());
+        final long prefix = prefixComputer.computePrefix(row);
+
+        inMemSorter.insertRecord(address, prefix);
+      }
+
+      sorter = UnsafeExternalSorter.createWithExistingInMemorySorter(
+        taskContext.taskMemoryManager(),
+        shuffleMemoryManager,
+        blockManager,
+        taskContext,
+        new KVComparator(ordering, keySchema.length()),
+        prefixComparator,
+        /* initialSize */ 4096,
+        pageSizeBytes,
+        inMemSorter);
+
+      sorter.spill();
+      map.free();
+    }
+  }
+
+  /**
+   * Inserts a key-value record into the sorter. If the sorter no longer has enough memory to hold
+   * the record, the sorter sorts the existing records in-memory, writes them out as partially
+   * sorted runs, and then reallocates memory to hold the new record.
+   */
+  public void insertKV(UnsafeRow key, UnsafeRow value) throws IOException {
+    final long prefix = prefixComputer.computePrefix(key);
+    sorter.insertKVRecord(
+      key.getBaseObject(), key.getBaseOffset(), key.getSizeInBytes(),
+      value.getBaseObject(), value.getBaseOffset(), value.getSizeInBytes(), prefix);
+  }
+
+  public KVIterator<UnsafeRow, UnsafeRow> sortedIterator() throws IOException {
+    try {
+      final UnsafeSorterIterator underlying = sorter.getSortedIterator();
+      if (!underlying.hasNext()) {
+        // Since we won't ever call next() on an empty iterator, we need to clean up resources
+        // here in order to prevent memory leaks.
+        cleanupResources();
+      }
+
+      return new KVIterator<UnsafeRow, UnsafeRow>() {
+        private UnsafeRow key = new UnsafeRow();
+        private UnsafeRow value = new UnsafeRow();
+        private int numKeyFields = keySchema.size();
+        private int numValueFields = valueSchema.size();
+
+        @Override
+        public boolean next() throws IOException {
+          try {
+            if (underlying.hasNext()) {
+              underlying.loadNext();
+
+              Object baseObj = underlying.getBaseObject();
+              long recordOffset = underlying.getBaseOffset();
+              int recordLen = underlying.getRecordLength();
+
+              // Note that recordLen = keyLen + valueLen + 4 bytes (for the keyLen itself)
+              int keyLen = PlatformDependent.UNSAFE.getInt(baseObj, recordOffset);
+              int valueLen = recordLen - keyLen - 4;
+
+              key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen);
+              value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, valueLen);
+
+              return true;
+            } else {
+              key = null;
+              value = null;
+              cleanupResources();
+              return false;
+            }
+          } catch (IOException e) {
+            cleanupResources();
+            throw e;
+          }
+        }
+
+        @Override
+        public UnsafeRow getKey() {
+          return key;
+        }
+
+        @Override
+        public UnsafeRow getValue() {
+          return value;
+        }
+
+        @Override
+        public void close() {
+          cleanupResources();
+        }
+      };
+    } catch (IOException e) {
+      cleanupResources();
+      throw e;
+    }
+  }
+
+  /**
+   * Marks the current page as having no space left, so that the sorter either allocates a
+   * new page or spills when it sees the next record.
+   */
+  @VisibleForTesting
+  void closeCurrentPage() {
+    sorter.closeCurrentPage();
+  }
+
+  private void cleanupResources() {
+    sorter.freeMemory();
+  }
+
+  private static final class KVComparator extends RecordComparator {
+    private final BaseOrdering ordering;
+    private final UnsafeRow row1 = new UnsafeRow();
+    private final UnsafeRow row2 = new UnsafeRow();
+    private final int numKeyFields;
+
+    public KVComparator(BaseOrdering ordering, int numKeyFields) {
+      this.numKeyFields = numKeyFields;
+      this.ordering = ordering;
+    }
+
+    @Override
+    public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) {
+      // Note that since ordering doesn't need the total length of the record, we just pass -1
+      // into the row.
+      row1.pointTo(baseObj1, baseOff1 + 4, numKeyFields, -1);
+      row2.pointTo(baseObj2, baseOff2 + 4, numKeyFields, -1);
+      return ordering.compare(row1, row2);
+    }
+  }
+}
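The iterator returned by UnsafeKVExternalSorter.sortedIterator() decodes each record as a 4-byte key length, then the key bytes, then the value bytes, so valueLen = recordLen - keyLen - 4. The removed sortedIterator() in UnsafeFixedWidthAggregationMap passed recordLen - keyLen as the value length, which overcounts by the 4-byte length word. A sketch of that arithmetic, assuming a ByteBuffer stands in for the raw memory and using a hypothetical KVRecordLayout class:

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

/** Hypothetical illustration of the [keyLen][key][value] record layout. */
public final class KVRecordLayout {
  static ByteBuffer encode(byte[] key, byte[] value) {
    ByteBuffer buf = ByteBuffer.allocate(4 + key.length + value.length)
        .order(ByteOrder.nativeOrder());
    buf.putInt(key.length).put(key).put(value);
    buf.flip(); // make the written bytes readable from position 0
    return buf;
  }

  /** Returns {keyOffset, keyLen, valueOffset, valueLen} for a record starting at offset 0. */
  static int[] decode(ByteBuffer record, int recordLen) {
    int keyLen = record.getInt(0);
    int valueLen = recordLen - keyLen - 4; // subtract the 4-byte keyLen word too
    return new int[] {4, keyLen, 4 + keyLen, valueLen};
  }

  public static void main(String[] args) {
    byte[] key = "k1".getBytes(StandardCharsets.UTF_8);
    byte[] value = "value".getBytes(StandardCharsets.UTF_8);
    ByteBuffer rec = encode(key, value);
    System.out.println(Arrays.toString(decode(rec, rec.remaining()))); // [4, 2, 6, 5]
  }
}
```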

http://git-wip-us.apache.org/repos/asf/spark/blob/2e981b7b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
index 2e870ec..49adf21 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
@@ -50,17 +50,36 @@ object SortPrefixUtils {
     }
   }
 
+  /**
+   * Creates the prefix comparator for the first field in the given schema, in ascending order.
+   */
   def getPrefixComparator(schema: StructType): PrefixComparator = {
-    val field = schema.head
-    getPrefixComparator(SortOrder(BoundReference(0, field.dataType, field.nullable), Ascending))
+    if (schema.nonEmpty) {
+      val field = schema.head
+      getPrefixComparator(SortOrder(BoundReference(0, field.dataType, field.nullable), Ascending))
+    } else {
+      new PrefixComparator {
+        override def compare(prefix1: Long, prefix2: Long): Int = 0
+      }
+    }
   }
 
+  /**
+   * Creates the prefix computer for the first field in the given schema, in ascending order.
+   */
   def createPrefixGenerator(schema: StructType): UnsafeExternalRowSorter.PrefixComputer = {
-    val boundReference = BoundReference(0, schema.head.dataType, nullable = true)
-    val prefixProjection = UnsafeProjection.create(SortPrefix(SortOrder(boundReference, Ascending)))
-    new UnsafeExternalRowSorter.PrefixComputer {
-      override def computePrefix(row: InternalRow): Long = {
-        prefixProjection.apply(row).getLong(0)
+    if (schema.nonEmpty) {
+      val boundReference = BoundReference(0, schema.head.dataType, nullable = true)
+      val prefixProjection = UnsafeProjection.create(
+        SortPrefix(SortOrder(boundReference, Ascending)))
+      new UnsafeExternalRowSorter.PrefixComputer {
+        override def computePrefix(row: InternalRow): Long = {
+          prefixProjection.apply(row).getLong(0)
+        }
+      }
+    } else {
+      new UnsafeExternalRowSorter.PrefixComputer {
+        override def computePrefix(row: InternalRow): Long = 0
       }
     }
   }
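For an empty key schema, SortPrefixUtils now returns a prefix computer that always yields 0 and a comparator that always reports a tie. Roughly, an in-memory prefix sort compares the cheap 64-bit prefixes first and only calls the full record comparator on a tie, so a constant prefix simply defers every decision to the record comparator. A sketch of that two-level comparison, assuming hypothetical names and signed Long.compare for prefixes (Spark's prefix comparators vary by data type):

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

/** Hypothetical two-level sort: prefix first, full record comparison on ties. */
public final class PrefixThenRecordSort {
  /** Hypothetical analogue of UnsafeExternalRowSorter.PrefixComputer. */
  interface PrefixComputer<T> {
    long computePrefix(T record);
  }

  static <T> List<T> sort(List<T> records, PrefixComputer<T> prefixComputer,
                          Comparator<T> recordComparator) {
    List<T> out = new ArrayList<>(records);
    out.sort((a, b) -> {
      // Cheap check first: compare the 64-bit prefixes.
      int c = Long.compare(prefixComputer.computePrefix(a), prefixComputer.computePrefix(b));
      // Only on a prefix tie fall back to the (more expensive) full record comparison.
      return c != 0 ? c : recordComparator.compare(a, b);
    });
    return out;
  }

  public static void main(String[] args) {
    List<String> rows = Arrays.asList("pear", "apple", "plum", "fig");
    // Prefix = first character, standing in for a real key prefix.
    List<String> byPrefix = sort(rows, s -> s.charAt(0), Comparator.naturalOrder());
    // Constant prefix, as for an empty key schema: every pair ties on the prefix.
    List<String> noPrefix = sort(rows, s -> 0L, Comparator.naturalOrder());
    System.out.println(byPrefix); // [apple, fig, pear, plum]
    System.out.println(noPrefix); // [apple, fig, pear, plum]
  }
}
```

Both calls produce the same order; the constant prefix only removes the fast path, it does not change correctness.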

http://git-wip-us.apache.org/repos/asf/spark/blob/2e981b7b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
index 6d903ab..92cf328 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
@@ -116,6 +116,7 @@ case class TungstenSort(
   protected override def doExecute(): RDD[InternalRow] = {
     val schema = child.schema
     val childOutput = child.output
+    val pageSize = sparkContext.conf.getSizeAsBytes("spark.buffer.pageSize", "64m")
     child.execute().mapPartitions({ iter =>
       val ordering = newOrdering(sortOrder, childOutput)
 
@@ -131,7 +132,8 @@ case class TungstenSort(
         }
       }
 
-      val sorter = new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer)
+      val sorter = new UnsafeExternalRowSorter(
+        schema, ordering, prefixComparator, prefixComputer, pageSize)
       if (testSpillFrequency > 0) {
         sorter.setTestSpillFrequency(testSpillFrequency)
       }
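TungstenSort and UnsafeExternalSorterSuite both read spark.buffer.pageSize with a default of "64m", while UnsafeFixedWidthAggregationMapSuite hard-codes PAGE_SIZE_BYTES = 1L << 26. Assuming binary (1024-based) units, as Spark's size strings use, the two values agree. A hypothetical SizeStrings.toBytes parser (not Spark's getSizeAsBytes, and without overflow checks) makes the arithmetic explicit:

```java
import java.util.Locale;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/** Hypothetical size-string parser using binary units (1k = 1024 bytes). */
public final class SizeStrings {
  private static final Pattern SIZE = Pattern.compile("(\\d+)([a-z]*)");

  static long toBytes(String str) {
    Matcher m = SIZE.matcher(str.trim().toLowerCase(Locale.ROOT));
    if (!m.matches()) {
      throw new IllegalArgumentException("Bad size string: " + str);
    }
    long n = Long.parseLong(m.group(1));
    switch (m.group(2)) {
      case "": case "b": return n;
      case "k": case "kb": return n << 10;
      case "m": case "mb": return n << 20;
      case "g": case "gb": return n << 30;
      case "t": case "tb": return n << 40;
      default: throw new IllegalArgumentException("Unknown unit in: " + str);
    }
  }

  public static void main(String[] args) {
    long pageSize = toBytes("64m");
    System.out.println(pageSize);               // 67108864
    System.out.println(pageSize == (1L << 26)); // true
  }
}
```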

http://git-wip-us.apache.org/repos/asf/spark/blob/2e981b7b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala
new file mode 100644
index 0000000..53de2d0
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.shuffle.ShuffleMemoryManager
+
+/**
+ * A [[ShuffleMemoryManager]] that can be controlled to run out of memory.
+ */
+class TestShuffleMemoryManager extends ShuffleMemoryManager(Long.MaxValue) {
+  private var oom = false
+
+  override def tryToAcquire(numBytes: Long): Long = {
+    if (oom) {
+      oom = false
+      0
+    } else {
+      // Uncomment the following to trace memory allocations.
+      // println(s"tryToAcquire $numBytes in " +
+      //   Thread.currentThread().getStackTrace.mkString("", "\n  -", ""))
+      val acquired = super.tryToAcquire(numBytes)
+      acquired
+    }
+  }
+
+  override def release(numBytes: Long): Unit = {
+    // Uncomment the following to trace memory releases.
+    // println(s"release $numBytes in " +
+    //   Thread.currentThread().getStackTrace.mkString("", "\n  -", ""))
+    super.release(numBytes)
+  }
+
+  def markAsOutOfMemory(): Unit = {
+    oom = true
+  }
+}
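TestShuffleMemoryManager injects a single failed allocation: after markAsOutOfMemory(), the next tryToAcquire call grants 0 bytes and the flag resets, which lets a test force a spill at a chosen point. A Java rendering of the same idea, assuming a hypothetical OneShotOomMemoryManager with simple byte counting in place of ShuffleMemoryManager(Long.MaxValue):

```java
/** Hypothetical one-shot failure injection, modeled on TestShuffleMemoryManager. */
public final class OneShotOomMemoryManager {
  private long acquired = 0;
  private boolean oom = false;

  long tryToAcquire(long numBytes) {
    if (oom) {
      oom = false; // fail only this one request
      return 0;
    }
    acquired += numBytes; // unbounded pool: always grant the full request
    return numBytes;
  }

  void release(long numBytes) {
    acquired -= numBytes;
  }

  void markAsOutOfMemory() {
    oom = true;
  }

  long getMemoryConsumptionForThisTask() {
    return acquired;
  }

  public static void main(String[] args) {
    OneShotOomMemoryManager mm = new OneShotOomMemoryManager();
    mm.markAsOutOfMemory();
    System.out.println(mm.tryToAcquire(1024)); // 0: the caller is expected to spill
    System.out.println(mm.tryToAcquire(1024)); // 1024: the next request succeeds
    mm.release(1024);
    System.out.println(mm.getMemoryConsumptionForThisTask()); // 0
  }
}
```

Because the failed request is never counted as acquired, the leak check in the suite's cleanup still expects 0 once everything granted has been released.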

http://git-wip-us.apache.org/repos/asf/spark/blob/2e981b7b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index 098bdd0..4c94b33 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -17,24 +17,26 @@
 
 package org.apache.spark.sql.execution
 
-import org.scalatest.{BeforeAndAfterEach, Matchers}
-
-import scala.collection.JavaConverters._
+import scala.util.control.NonFatal
 import scala.collection.mutable
-import scala.util.Random
+import scala.util.{Try, Random}
+
+import org.scalatest.Matchers
 
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.shuffle.ShuffleMemoryManager
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.{TaskContextImpl, TaskContext, SparkFunSuite}
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
 import org.apache.spark.unsafe.types.UTF8String
 
-
-class UnsafeFixedWidthAggregationMapSuite
-  extends SparkFunSuite
-  with Matchers
-  with BeforeAndAfterEach {
+/**
+ * Test suite for [[UnsafeFixedWidthAggregationMap]].
+ *
+ * Use [[testWithMemoryLeakDetection]] rather than [[test]] to construct test cases.
+ */
+class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
 
   import UnsafeFixedWidthAggregationMap._
 
@@ -44,23 +46,40 @@ class UnsafeFixedWidthAggregationMapSuite
   private val PAGE_SIZE_BYTES: Long = 1L << 26; // 64 megabytes
 
   private var taskMemoryManager: TaskMemoryManager = null
-  private var shuffleMemoryManager: ShuffleMemoryManager = null
+  private var shuffleMemoryManager: TestShuffleMemoryManager = null
+
+  def testWithMemoryLeakDetection(name: String)(f: => Unit) {
+    def cleanup(): Unit = {
+      if (taskMemoryManager != null) {
+        val leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask()
+        assert(taskMemoryManager.cleanUpAllAllocatedMemory() === 0)
+        assert(leakedShuffleMemory === 0)
+        taskMemoryManager = null
+      }
+    }
 
-  override def beforeEach(): Unit = {
-    taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
-    shuffleMemoryManager = new ShuffleMemoryManager(Long.MaxValue)
+    test(name) {
+      taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
+      shuffleMemoryManager = new TestShuffleMemoryManager
+      try {
+        f
+      } catch {
+        case NonFatal(e) =>
+          Try(cleanup())
+          throw e
+      }
+      cleanup()
+    }
   }
 
-  override def afterEach(): Unit = {
-    if (taskMemoryManager != null) {
-      val leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask()
-      assert(taskMemoryManager.cleanUpAllAllocatedMemory() === 0)
-      assert(leakedShuffleMemory === 0)
-      taskMemoryManager = null
-    }
+  private def randomStrings(n: Int): Seq[String] = {
+    val rand = new Random(42)
+    Seq.fill(n) {
+      Seq.fill(rand.nextInt(100))(rand.nextPrintableChar()).mkString
+    }.distinct
   }
 
-  test("supported schemas") {
+  testWithMemoryLeakDetection("supported schemas") {
     assert(supportsAggregationBufferSchema(
       StructType(StructField("x", DecimalType.USER_DEFAULT) :: Nil)))
     assert(!supportsAggregationBufferSchema(
@@ -70,7 +89,7 @@ class UnsafeFixedWidthAggregationMapSuite
       !supportsAggregationBufferSchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil)))
   }
 
-  test("empty map") {
+  testWithMemoryLeakDetection("empty map") {
     val map = new UnsafeFixedWidthAggregationMap(
       emptyAggregationBuffer,
       aggBufferSchema,
@@ -85,7 +104,7 @@ class UnsafeFixedWidthAggregationMapSuite
     map.free()
   }
 
-  test("updating values for a single key") {
+  testWithMemoryLeakDetection("updating values for a single key") {
     val map = new UnsafeFixedWidthAggregationMap(
       emptyAggregationBuffer,
       aggBufferSchema,
@@ -113,7 +132,7 @@ class UnsafeFixedWidthAggregationMapSuite
     map.free()
   }
 
-  test("inserting large random keys") {
+  testWithMemoryLeakDetection("inserting large random keys") {
     val map = new UnsafeFixedWidthAggregationMap(
       emptyAggregationBuffer,
       aggBufferSchema,
@@ -140,7 +159,21 @@ class UnsafeFixedWidthAggregationMapSuite
     map.free()
   }
 
-  test("test sorting") {
+  testWithMemoryLeakDetection("test external sorting") {
+    // Referencing TestSQLContext makes sure the block manager and everything else are set up.
+    TestSQLContext
+
+    TaskContext.setTaskContext(new TaskContextImpl(
+      stageId = 0,
+      partitionId = 0,
+      taskAttemptId = 0,
+      attemptNumber = 0,
+      taskMemoryManager = taskMemoryManager,
+      metricsSystem = null))
+
+    // Memory consumption at the beginning of the task.
+    val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask()
+
     val map = new UnsafeFixedWidthAggregationMap(
       emptyAggregationBuffer,
       aggBufferSchema,
@@ -152,26 +185,47 @@ class UnsafeFixedWidthAggregationMapSuite
       false // disable perf metrics
     )
 
-    val rand = new Random(42)
-    val groupKeys: Set[String] = Seq.fill(512) {
-      Seq.fill(rand.nextInt(100))(rand.nextPrintableChar()).mkString
-    }.toSet
-    groupKeys.foreach { keyString =>
+    val keys = randomStrings(1024).take(512)
+    keys.foreach { keyString =>
       val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(keyString)))
       assert(buf != null)
       buf.setInt(0, keyString.length)
     }
 
+    // Convert the map into a sorter
+    val sorter = map.destructAndCreateExternalSorter()
+
+    withClue(s"destructAndCreateExternalSorter should release memory used by the map") {
+      // 4096 * 16 is the initial size allocated for the pointer/prefix array in the in-mem sorter.
+      assert(shuffleMemoryManager.getMemoryConsumptionForThisTask() ===
+        initialMemoryConsumption + 4096 * 16)
+    }
+
+    // Add more keys to the sorter and make sure the results come out sorted.
+    val additionalKeys = randomStrings(1024)
+    val keyConverter = UnsafeProjection.create(groupKeySchema)
+    val valueConverter = UnsafeProjection.create(aggBufferSchema)
+
+    additionalKeys.zipWithIndex.foreach { case (str, i) =>
+      val k = InternalRow(UTF8String.fromString(str))
+      val v = InternalRow(str.length)
+      sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v))
+
+      if ((i % 100) == 0) {
+        shuffleMemoryManager.markAsOutOfMemory()
+        sorter.closeCurrentPage()
+      }
+    }
+
     val out = new scala.collection.mutable.ArrayBuffer[String]
-    val iter = map.sortedIterator()
+    val iter = sorter.sortedIterator()
     while (iter.next()) {
       assert(iter.getKey.getString(0).length === iter.getValue.getInt(0))
       out += iter.getKey.getString(0)
     }
 
-    assert(out === groupKeys.toSeq.sorted)
+    assert(out === (keys ++ additionalKeys).sorted)
 
     map.free()
   }
-
 }
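The testWithMemoryLeakDetection helper above follows a loan pattern: create fresh memory managers, run the test body, then assert that nothing leaked, while making sure a failing cleanup never masks the test's own exception. A minimal, self-contained sketch of that control flow, assuming a simplified tracker that only counts bytes; FakeMemoryTracker and LeakDetection are hypothetical names, not Spark classes.

```scala
import scala.util.Try
import scala.util.control.NonFatal

// Hypothetical stand-in for TaskMemoryManager/ShuffleMemoryManager bookkeeping:
// it only counts bytes acquired and not yet released.
final class FakeMemoryTracker {
  private var used: Long = 0L
  def acquire(bytes: Long): Unit = used += bytes
  def release(bytes: Long): Unit = used -= bytes
  // Returns the bytes still held, then resets the counter (like a cleanup pass).
  def cleanUpAll(): Long = {
    val leaked = used
    used = 0L
    leaked
  }
}

object LeakDetection {
  // Runs `body` against a fresh tracker and fails if anything is still allocated.
  // If `body` throws, cleanup is still attempted, but the original error is rethrown.
  def withLeakDetection[T](body: FakeMemoryTracker => T): T = {
    val tracker = new FakeMemoryTracker
    def cleanup(): Unit = {
      val leaked = tracker.cleanUpAll()
      if (leaked != 0L) throw new AssertionError(s"leaked $leaked bytes")
    }
    val result =
      try body(tracker)
      catch {
        case NonFatal(e) =>
          Try(cleanup()) // ignore cleanup failures so the real error surfaces
          throw e
      }
    cleanup()
    result
  }
}
```

Wrapping the cleanup in Try on the failure path mirrors the suite's `Try(cleanup())`: the leak check still runs, but the reported failure is the one the test body raised.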

http://git-wip-us.apache.org/repos/asf/spark/blob/2e981b7b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
new file mode 100644
index 0000000..5d214d7
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import scala.util.Random
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{RowOrdering, UnsafeProjection}
+import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark._
+
+class UnsafeKVExternalSorterSuite extends SparkFunSuite {
+
+  test("sorting string key and int int value") {
+
+    // Referencing TestSQLContext makes sure the block manager and everything else are set up.
+    TestSQLContext
+
+    val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
+    val shuffleMemMgr = new TestShuffleMemoryManager
+
+    TaskContext.setTaskContext(new TaskContextImpl(
+      stageId = 0,
+      partitionId = 0,
+      taskAttemptId = 0,
+      attemptNumber = 0,
+      taskMemoryManager = taskMemMgr,
+      metricsSystem = null))
+
+    val keySchema = new StructType().add("a", StringType)
+    val valueSchema = new StructType().add("b", IntegerType).add("c", IntegerType)
+    val sorter = new UnsafeKVExternalSorter(
+      keySchema, valueSchema, SparkEnv.get.blockManager, shuffleMemMgr,
+      16 * 1024)
+
+    val keyConverter = UnsafeProjection.create(keySchema)
+    val valueConverter = UnsafeProjection.create(valueSchema)
+
+    val rand = new Random(42)
+    val data = null +: Seq.fill[String](10) {
+      Seq.fill(rand.nextInt(100))(rand.nextPrintableChar()).mkString
+    }
+
+    val inputRows = data.map { str =>
+      keyConverter.apply(InternalRow(UTF8String.fromString(str))).copy()
+    }
+
+    var i = 0
+    data.foreach { str =>
+      if (str != null) {
+        val k = InternalRow(UTF8String.fromString(str))
+        val v = InternalRow(str.length, str.length + 1)
+        sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v))
+      } else {
+        val k = InternalRow(UTF8String.fromString(str))
+        val v = InternalRow(-1, -2)
+        sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v))
+      }
+
+      if ((i % 100) == 0) {
+        shuffleMemMgr.markAsOutOfMemory()
+        sorter.closeCurrentPage()
+      }
+      i += 1
+    }
+
+    val out = new scala.collection.mutable.ArrayBuffer[InternalRow]
+    val iter = sorter.sortedIterator()
+    while (iter.next()) {
+      if (iter.getKey.getUTF8String(0) == null) {
+        withClue(s"for null key") {
+          assert(-1 === iter.getValue.getInt(0))
+          assert(-2 === iter.getValue.getInt(1))
+        }
+      } else {
+        val key = iter.getKey.getString(0)
+        withClue(s"for key $key") {
+          assert(key.length === iter.getValue.getInt(0))
+          assert(key.length + 1 === iter.getValue.getInt(1))
+        }
+      }
+      out += iter.getKey.copy()
+    }
+
+    assert(out === inputRows.sorted(RowOrdering.forSchema(keySchema.map(_.dataType))))
+  }
+
+  test("sorting arbitrary string data") {
+
+    // Referencing TestSQLContext makes sure the block manager and everything else are set up.
+    TestSQLContext
+
+    val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
+    val shuffleMemMgr = new TestShuffleMemoryManager
+
+    TaskContext.setTaskContext(new TaskContextImpl(
+      stageId = 0,
+      partitionId = 0,
+      taskAttemptId = 0,
+      attemptNumber = 0,
+      taskMemoryManager = taskMemMgr,
+      metricsSystem = null))
+
+    val keySchema = new StructType().add("a", StringType)
+    val valueSchema = new StructType().add("b", IntegerType)
+    val sorter = new UnsafeKVExternalSorter(
+      keySchema, valueSchema, SparkEnv.get.blockManager, shuffleMemMgr,
+      16 * 1024)
+
+    val keyConverter = UnsafeProjection.create(keySchema)
+    val valueConverter = UnsafeProjection.create(valueSchema)
+
+    val rand = new Random(42)
+    val data = Seq.fill(512) {
+      Seq.fill(rand.nextInt(100))(rand.nextPrintableChar()).mkString
+    }
+
+    var i = 0
+    data.foreach { str =>
+      val k = InternalRow(UTF8String.fromString(str))
+      val v = InternalRow(str.length)
+      sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v))
+
+      if ((i % 100) == 0) {
+        shuffleMemMgr.markAsOutOfMemory()
+        sorter.closeCurrentPage()
+      }
+      i += 1
+    }
+
+    val out = new scala.collection.mutable.ArrayBuffer[String]
+    val iter = sorter.sortedIterator()
+    while (iter.next()) {
+      assert(iter.getKey.getString(0).length === iter.getValue.getInt(0))
+      out += iter.getKey.getString(0)
+    }
+
+    assert(out === data.sorted)
+  }
+}
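Both sorter tests check one invariant: records inserted while spills are forced (markAsOutOfMemory plus closeCurrentPage every 100 records) must come back in exactly the order an in-memory sort would produce. A toy model of that spill-and-merge behavior is sketched below as one common way to get the invariant; ToyExternalSorter is hypothetical, not UnsafeKVExternalSorter, it keeps sorted runs in memory instead of writing spill files, and spilling after a fixed record count (rather than on memory pressure) is an assumption.

```scala
import scala.collection.mutable

// Hypothetical toy model, not UnsafeKVExternalSorter: sorted runs are kept in
// memory as Vectors instead of being written to spill files on disk.
final class ToyExternalSorter[K](runSize: Int)(implicit ord: Ordering[K]) {
  private val runs = mutable.ArrayBuffer.empty[Vector[K]]
  private val buffer = mutable.ArrayBuffer.empty[K]

  def insert(key: K): Unit = {
    buffer += key
    // Assumption: "spill" after a fixed record count rather than on memory pressure.
    if (buffer.size >= runSize) spill()
  }

  // Sort the in-memory buffer and hand it off as one more sorted run.
  private def spill(): Unit =
    if (buffer.nonEmpty) {
      runs += buffer.sorted.toVector
      buffer.clear()
    }

  // K-way merge: a min-heap ordered by the current head of each run.
  def sortedIterator(): Iterator[K] = {
    spill()
    val byHead = Ordering.by[BufferedIterator[K], K](_.head).reverse
    val heap = mutable.PriorityQueue.empty[BufferedIterator[K]](byHead)
    runs.foreach(run => heap.enqueue(run.iterator.buffered))
    new Iterator[K] {
      def hasNext: Boolean = heap.nonEmpty
      def next(): K = {
        val run = heap.dequeue()
        val key = run.next()
        if (run.hasNext) heap.enqueue(run) // re-insert with its new head
        key
      }
    }
  }
}
```

Keying the heap on each run's head keeps the merge at O(log r) work per record for r runs. Modeling a nullable key as Option[String] is one way to mimic the null key in the first test; under Scala's default Option ordering, None sorts first.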


