Repository: spark
Updated Branches:
refs/heads/branch-2.0 0cf31f0c8 -> beb753004
[SPARK-14851][CORE] Support radix sort with nullable longs
## What changes were proposed in this pull request?
This adds support for radix sorting nullable long fields. When a sort field is
null and radix sort is enabled, we keep nulls in a separate region of the sort
buffer so that radix sort does not need to deal with them. This also has
performance benefits when sorting smaller integer types. Under the current
representation, nulls are encoded as Long.MIN_VALUE (in two's complement), and
that otherwise forces a full-width radix sort.
Note that this strategy for nulls means the sort is no longer stable. cc davies
## How was this patch tested?
Correctness is covered by the existing randomized sort tests. I also ran some
TPC-DS queries, and there does not seem to be any significant regression for
non-null sorts.
Some test queries (best of 5 runs each).
Before change:
scala> val start = System.nanoTime; spark.range(5000000).selectExpr("if(id > 5, cast(hash(id) as long), NULL) as h").coalesce(1).orderBy("h").collect(); (System.nanoTime - start) / 1e6
start: Long = 3190437233227987
res3: Double = 4716.471091
After change:
scala> val start = System.nanoTime; spark.range(5000000).selectExpr("if(id > 5, cast(hash(id) as long), NULL) as h").coalesce(1).orderBy("h").collect(); (System.nanoTime - start) / 1e6
start: Long = 3190367870952791
res4: Double = 2981.143045
Author: Eric Liang <[email protected]>
Closes #13161 from ericl/sc-2998.
(cherry picked from commit c06c58bbbb2de0c22cfc70c486d23a94c3079ba4)
Signed-off-by: Reynold Xin <[email protected]>
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/beb75300
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/beb75300
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/beb75300
Branch: refs/heads/branch-2.0
Commit: beb75300455a4f92000b69e740256102d9f2d472
Parents: 0cf31f0
Author: Eric Liang <[email protected]>
Authored: Sat Jun 11 15:42:58 2016 -0700
Committer: Reynold Xin <[email protected]>
Committed: Sat Jun 11 15:43:03 2016 -0700
----------------------------------------------------------------------
.../util/collection/unsafe/sort/RadixSort.java | 24 +++++----
.../unsafe/sort/UnsafeExternalSorter.java | 11 ++--
.../unsafe/sort/UnsafeInMemorySorter.java | 56 ++++++++++++++++----
.../unsafe/sort/UnsafeExternalSorterSuite.java | 26 ++++-----
.../unsafe/sort/UnsafeInMemorySorterSuite.java | 2 +-
.../collection/unsafe/sort/RadixSortSuite.scala | 4 +-
.../sql/execution/UnsafeExternalRowSorter.java | 20 +++++--
.../sql/catalyst/expressions/SortOrder.scala | 40 ++++++++------
.../sql/execution/UnsafeKVExternalSorter.java | 11 ++--
.../apache/spark/sql/execution/SortExec.scala | 12 +++--
.../spark/sql/execution/SortPrefixUtils.scala | 32 +++++++----
.../apache/spark/sql/execution/WindowExec.scala | 4 +-
.../execution/joins/CartesianProductExec.scala | 2 +-
.../apache/spark/sql/execution/SortSuite.scala | 11 ++++
.../sql/execution/benchmark/SortBenchmark.scala | 2 +-
15 files changed, 178 insertions(+), 79 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/beb75300/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java
----------------------------------------------------------------------
diff --git
a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java
index 4f3f0de..4043617 100644
---
a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java
+++
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java
@@ -170,9 +170,13 @@ public class RadixSort {
/**
* Specialization of sort() for key-prefix arrays. In this type of array,
each record consists
* of two longs, only the second of which is sorted on.
+ *
+ * @param startIndex starting index in the array to sort from. This
parameter is not supported
+ * in the plain sort() implementation.
*/
public static int sortKeyPrefixArray(
LongArray array,
+ int startIndex,
int numRecords,
int startByteIndex,
int endByteIndex,
@@ -182,10 +186,11 @@ public class RadixSort {
assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <=
7";
assert endByteIndex > startByteIndex;
assert numRecords * 4 <= array.size();
- int inIndex = 0;
- int outIndex = numRecords * 2;
+ int inIndex = startIndex;
+ int outIndex = startIndex + numRecords * 2;
if (numRecords > 0) {
- long[][] counts = getKeyPrefixArrayCounts(array, numRecords,
startByteIndex, endByteIndex);
+ long[][] counts = getKeyPrefixArrayCounts(
+ array, startIndex, numRecords, startByteIndex, endByteIndex);
for (int i = startByteIndex; i <= endByteIndex; i++) {
if (counts[i] != null) {
sortKeyPrefixArrayAtByte(
@@ -205,13 +210,14 @@ public class RadixSort {
* getCounts with some added parameters but that seems to hurt in benchmarks.
*/
private static long[][] getKeyPrefixArrayCounts(
- LongArray array, int numRecords, int startByteIndex, int endByteIndex) {
+ LongArray array, int startIndex, int numRecords, int startByteIndex, int
endByteIndex) {
long[][] counts = new long[8][];
long bitwiseMax = 0;
long bitwiseMin = -1L;
- long limit = array.getBaseOffset() + numRecords * 16;
+ long baseOffset = array.getBaseOffset() + startIndex * 8L;
+ long limit = baseOffset + numRecords * 16L;
Object baseObject = array.getBaseObject();
- for (long offset = array.getBaseOffset(); offset < limit; offset += 16) {
+ for (long offset = baseOffset; offset < limit; offset += 16) {
long value = Platform.getLong(baseObject, offset + 8);
bitwiseMax |= value;
bitwiseMin &= value;
@@ -220,7 +226,7 @@ public class RadixSort {
for (int i = startByteIndex; i <= endByteIndex; i++) {
if (((bitsChanged >>> (i * 8)) & 0xff) != 0) {
counts[i] = new long[256];
- for (long offset = array.getBaseOffset(); offset < limit; offset +=
16) {
+ for (long offset = baseOffset; offset < limit; offset += 16) {
counts[i][(int)((Platform.getLong(baseObject, offset + 8) >>> (i *
8)) & 0xff)]++;
}
}
@@ -238,8 +244,8 @@ public class RadixSort {
long[] offsets = transformCountsToOffsets(
counts, numRecords, array.getBaseOffset() + outIndex * 8, 16, desc,
signed);
Object baseObject = array.getBaseObject();
- long baseOffset = array.getBaseOffset() + inIndex * 8;
- long maxOffset = baseOffset + numRecords * 16;
+ long baseOffset = array.getBaseOffset() + inIndex * 8L;
+ long maxOffset = baseOffset + numRecords * 16L;
for (long offset = baseOffset; offset < maxOffset; offset += 16) {
long key = Platform.getLong(baseObject, offset);
long prefix = Platform.getLong(baseObject, offset + 8);
http://git-wip-us.apache.org/repos/asf/spark/blob/beb75300/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
----------------------------------------------------------------------
diff --git
a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index e14a23f..ec15f0b 100644
---
a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -369,7 +369,8 @@ public final class UnsafeExternalSorter extends
MemoryConsumer {
/**
* Write a record to the sorter.
*/
- public void insertRecord(Object recordBase, long recordOffset, int length,
long prefix)
+ public void insertRecord(
+ Object recordBase, long recordOffset, int length, long prefix, boolean
prefixIsNull)
throws IOException {
growPointerArrayIfNecessary();
@@ -384,7 +385,7 @@ public final class UnsafeExternalSorter extends
MemoryConsumer {
Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
pageCursor += length;
assert(inMemSorter != null);
- inMemSorter.insertRecord(recordAddress, prefix);
+ inMemSorter.insertRecord(recordAddress, prefix, prefixIsNull);
}
/**
@@ -396,7 +397,7 @@ public final class UnsafeExternalSorter extends
MemoryConsumer {
* record length = key length + value length + 4
*/
public void insertKVRecord(Object keyBase, long keyOffset, int keyLen,
- Object valueBase, long valueOffset, int valueLen, long prefix)
+ Object valueBase, long valueOffset, int valueLen, long prefix, boolean
prefixIsNull)
throws IOException {
growPointerArrayIfNecessary();
@@ -415,7 +416,7 @@ public final class UnsafeExternalSorter extends
MemoryConsumer {
pageCursor += valueLen;
assert(inMemSorter != null);
- inMemSorter.insertRecord(recordAddress, prefix);
+ inMemSorter.insertRecord(recordAddress, prefix, prefixIsNull);
}
/**
@@ -465,7 +466,7 @@ public final class UnsafeExternalSorter extends
MemoryConsumer {
private boolean loaded = false;
private int numRecords = 0;
- SpillableIterator(UnsafeInMemorySorter.SortedIterator inMemIterator) {
+ SpillableIterator(UnsafeSorterIterator inMemIterator) {
this.upstream = inMemIterator;
this.numRecords = inMemIterator.getNumRecords();
}
http://git-wip-us.apache.org/repos/asf/spark/blob/beb75300/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
----------------------------------------------------------------------
diff --git
a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index c7b070f..78da389 100644
---
a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -18,6 +18,7 @@
package org.apache.spark.util.collection.unsafe.sort;
import java.util.Comparator;
+import java.util.LinkedList;
import org.apache.avro.reflect.Nullable;
@@ -93,6 +94,14 @@ public final class UnsafeInMemorySorter {
private int pos = 0;
/**
+ * If sorting with radix sort, specifies the starting position in the sort
buffer where records
+ * with non-null prefixes are kept. Positions [0..nullBoundaryPos) will
contain null-prefixed
+ * records, and positions [nullBoundaryPos..pos) non-null prefixed records.
This lets us avoid
+ * radix sorting over null values.
+ */
+ private int nullBoundaryPos = 0;
+
+ /*
* How many records could be inserted, because part of the array should be
left for sorting.
*/
private int usableCapacity = 0;
@@ -160,6 +169,7 @@ public final class UnsafeInMemorySorter {
usableCapacity = getUsableCapacity();
}
pos = 0;
+ nullBoundaryPos = 0;
}
/**
@@ -206,14 +216,27 @@ public final class UnsafeInMemorySorter {
* @param recordPointer pointer to a record in a data page, encoded by
{@link TaskMemoryManager}.
* @param keyPrefix a user-defined key prefix
*/
- public void insertRecord(long recordPointer, long keyPrefix) {
+ public void insertRecord(long recordPointer, long keyPrefix, boolean
prefixIsNull) {
if (!hasSpaceForAnotherRecord()) {
throw new IllegalStateException("There is no space for new record");
}
- array.set(pos, recordPointer);
- pos++;
- array.set(pos, keyPrefix);
- pos++;
+ if (prefixIsNull && radixSortSupport != null) {
+ // Swap forward a non-null record to make room for this one at the
beginning of the array.
+ array.set(pos, array.get(nullBoundaryPos));
+ pos++;
+ array.set(pos, array.get(nullBoundaryPos + 1));
+ pos++;
+ // Place this record in the vacated position.
+ array.set(nullBoundaryPos, recordPointer);
+ nullBoundaryPos++;
+ array.set(nullBoundaryPos, keyPrefix);
+ nullBoundaryPos++;
+ } else {
+ array.set(pos, recordPointer);
+ pos++;
+ array.set(pos, keyPrefix);
+ pos++;
+ }
}
public final class SortedIterator extends UnsafeSorterIterator implements
Cloneable {
@@ -280,15 +303,14 @@ public final class UnsafeInMemorySorter {
* Return an iterator over record pointers in sorted order. For efficiency,
all calls to
* {@code next()} will return the same mutable object.
*/
- public SortedIterator getSortedIterator() {
+ public UnsafeSorterIterator getSortedIterator() {
int offset = 0;
long start = System.nanoTime();
if (sortComparator != null) {
if (this.radixSortSupport != null) {
- // TODO(ekl) we should handle NULL values before radix sort for
efficiency, since they
- // force a full-width sort (and we cannot radix-sort nullable long
fields at all).
offset = RadixSort.sortKeyPrefixArray(
- array, pos / 2, 0, 7, radixSortSupport.sortDescending(),
radixSortSupport.sortSigned());
+ array, nullBoundaryPos, (pos - nullBoundaryPos) / 2, 0, 7,
+ radixSortSupport.sortDescending(), radixSortSupport.sortSigned());
} else {
MemoryBlock unused = new MemoryBlock(
array.getBaseObject(),
@@ -301,6 +323,20 @@ public final class UnsafeInMemorySorter {
}
}
totalSortTimeNanos += System.nanoTime() - start;
- return new SortedIterator(pos / 2, offset);
+ if (nullBoundaryPos > 0) {
+ assert radixSortSupport != null : "Nulls are only stored separately with
radix sort";
+ LinkedList<UnsafeSorterIterator> queue = new LinkedList<>();
+ if (radixSortSupport.sortDescending()) {
+ // Nulls are smaller than non-nulls
+ queue.add(new SortedIterator((pos - nullBoundaryPos) / 2, offset));
+ queue.add(new SortedIterator(nullBoundaryPos / 2, 0));
+ } else {
+ queue.add(new SortedIterator(nullBoundaryPos / 2, 0));
+ queue.add(new SortedIterator((pos - nullBoundaryPos) / 2, offset));
+ }
+ return new UnsafeExternalSorter.ChainedIterator(queue);
+ } else {
+ return new SortedIterator(pos / 2, offset);
+ }
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/beb75300/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
----------------------------------------------------------------------
diff --git
a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index 2cae4be..bce958c 100644
---
a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++
b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -156,14 +156,14 @@ public class UnsafeExternalSorterSuite {
private static void insertNumber(UnsafeExternalSorter sorter, int value)
throws Exception {
final int[] arr = new int[]{ value };
- sorter.insertRecord(arr, Platform.INT_ARRAY_OFFSET, 4, value);
+ sorter.insertRecord(arr, Platform.INT_ARRAY_OFFSET, 4, value, false);
}
private static void insertRecord(
UnsafeExternalSorter sorter,
int[] record,
long prefix) throws IOException {
- sorter.insertRecord(record, Platform.INT_ARRAY_OFFSET, record.length * 4,
prefix);
+ sorter.insertRecord(record, Platform.INT_ARRAY_OFFSET, record.length * 4,
prefix, false);
}
private UnsafeExternalSorter newSorter() throws IOException {
@@ -206,13 +206,13 @@ public class UnsafeExternalSorterSuite {
@Test
public void testSortingEmptyArrays() throws Exception {
final UnsafeExternalSorter sorter = newSorter();
- sorter.insertRecord(null, 0, 0, 0);
- sorter.insertRecord(null, 0, 0, 0);
+ sorter.insertRecord(null, 0, 0, 0, false);
+ sorter.insertRecord(null, 0, 0, 0, false);
sorter.spill();
- sorter.insertRecord(null, 0, 0, 0);
+ sorter.insertRecord(null, 0, 0, 0, false);
sorter.spill();
- sorter.insertRecord(null, 0, 0, 0);
- sorter.insertRecord(null, 0, 0, 0);
+ sorter.insertRecord(null, 0, 0, 0, false);
+ sorter.insertRecord(null, 0, 0, 0, false);
UnsafeSorterIterator iter = sorter.getSortedIterator();
@@ -232,7 +232,7 @@ public class UnsafeExternalSorterSuite {
long prevSortTime = sorter.getSortTimeNanos();
assertEquals(prevSortTime, 0);
- sorter.insertRecord(null, 0, 0, 0);
+ sorter.insertRecord(null, 0, 0, 0, false);
sorter.spill();
assertThat(sorter.getSortTimeNanos(), greaterThan(prevSortTime));
prevSortTime = sorter.getSortTimeNanos();
@@ -240,7 +240,7 @@ public class UnsafeExternalSorterSuite {
sorter.spill(); // no sort needed
assertEquals(sorter.getSortTimeNanos(), prevSortTime);
- sorter.insertRecord(null, 0, 0, 0);
+ sorter.insertRecord(null, 0, 0, 0, false);
UnsafeSorterIterator iter = sorter.getSortedIterator();
assertThat(sorter.getSortTimeNanos(), greaterThan(prevSortTime));
}
@@ -280,7 +280,7 @@ public class UnsafeExternalSorterSuite {
final UnsafeExternalSorter sorter = newSorter();
byte[] record = new byte[16];
while (sorter.getNumberOfAllocatedPages() < 2) {
- sorter.insertRecord(record, Platform.BYTE_ARRAY_OFFSET, record.length,
0);
+ sorter.insertRecord(record, Platform.BYTE_ARRAY_OFFSET, record.length,
0, false);
}
sorter.cleanupResources();
assertSpillFilesWereCleanedUp();
@@ -340,7 +340,7 @@ public class UnsafeExternalSorterSuite {
int n = (int) pageSizeBytes / recordSize * 3;
for (int i = 0; i < n; i++) {
record[0] = (long) i;
- sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0);
+ sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0,
false);
}
assertTrue(sorter.getNumberOfAllocatedPages() >= 2);
UnsafeExternalSorter.SpillableIterator iter =
@@ -372,7 +372,7 @@ public class UnsafeExternalSorterSuite {
int n = (int) pageSizeBytes / recordSize * 3;
for (int i = 0; i < n; i++) {
record[0] = (long) i;
- sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0);
+ sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0,
false);
}
assertTrue(sorter.getNumberOfAllocatedPages() >= 2);
UnsafeExternalSorter.SpillableIterator iter =
@@ -406,7 +406,7 @@ public class UnsafeExternalSorterSuite {
int batch = n / 4;
for (int i = 0; i < n; i++) {
record[0] = (long) i;
- sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0);
+ sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0,
false);
if (i % batch == batch - 1) {
sorter.spill();
}
http://git-wip-us.apache.org/repos/asf/spark/blob/beb75300/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
----------------------------------------------------------------------
diff --git
a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
index 383c5b3..bd89085 100644
---
a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
+++
b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
@@ -120,7 +120,7 @@ public class UnsafeInMemorySorterSuite {
final long address = memoryManager.encodePageNumberAndOffset(dataPage,
position);
final String str = getStringFromDataPage(baseObject, position + 4,
recordLength);
final int partitionId = hashPartitioner.getPartition(str);
- sorter.insertRecord(address, partitionId);
+ sorter.insertRecord(address, partitionId, false);
position += 4 + recordLength;
}
final UnsafeSorterIterator iter = sorter.getSortedIterator();
http://git-wip-us.apache.org/repos/asf/spark/blob/beb75300/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
----------------------------------------------------------------------
diff --git
a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
index 1d26d4a..2c13806 100644
---
a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
+++
b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
@@ -152,7 +152,7 @@ class RadixSortSuite extends SparkFunSuite with Logging {
val (buf1, buf2) = generateKeyPrefixTestData(N, rand.nextLong & 0xff)
referenceKeyPrefixSort(buf1, 0, N, sortType.referenceComparator)
val outOffset = RadixSort.sortKeyPrefixArray(
- buf2, N, sortType.startByteIdx, sortType.endByteIdx,
+ buf2, 0, N, sortType.startByteIdx, sortType.endByteIdx,
sortType.descending, sortType.signed)
val res1 = collectToArray(buf1, 0, N * 2)
val res2 = collectToArray(buf2, outOffset, N * 2)
@@ -177,7 +177,7 @@ class RadixSortSuite extends SparkFunSuite with Logging {
val (buf1, buf2) = generateKeyPrefixTestData(N, rand.nextLong & mask)
referenceKeyPrefixSort(buf1, 0, N, sortType.referenceComparator)
val outOffset = RadixSort.sortKeyPrefixArray(
- buf2, N, sortType.startByteIdx, sortType.endByteIdx,
+ buf2, 0, N, sortType.startByteIdx, sortType.endByteIdx,
sortType.descending, sortType.signed)
val res1 = collectToArray(buf1, 0, N * 2)
val res2 = collectToArray(buf2, outOffset, N * 2)
http://git-wip-us.apache.org/repos/asf/spark/blob/beb75300/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 37fbad4..ad76bf5 100644
---
a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++
b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -51,7 +51,20 @@ public final class UnsafeExternalRowSorter {
private final UnsafeExternalSorter sorter;
public abstract static class PrefixComputer {
- abstract long computePrefix(InternalRow row);
+
+ public static class Prefix {
+ /** Key prefix value, or the null prefix value if isNull = true. **/
+ long value;
+
+ /** Whether the key is null. */
+ boolean isNull;
+ }
+
+ /**
+ * Computes prefix for the given row. For efficiency, the returned object
may be reused in
+ * further calls to a given PrefixComputer.
+ */
+ abstract Prefix computePrefix(InternalRow row);
}
public UnsafeExternalRowSorter(
@@ -88,12 +101,13 @@ public final class UnsafeExternalRowSorter {
}
public void insertRow(UnsafeRow row) throws IOException {
- final long prefix = prefixComputer.computePrefix(row);
+ final PrefixComputer.Prefix prefix = prefixComputer.computePrefix(row);
sorter.insertRecord(
row.getBaseObject(),
row.getBaseOffset(),
row.getSizeInBytes(),
- prefix
+ prefix.value,
+ prefix.isNull
);
numRowsInserted++;
if (testSpillFrequency > 0 && (numRowsInserted % testSpillFrequency) == 0)
{
http://git-wip-us.apache.org/repos/asf/spark/blob/beb75300/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index 42a8be6..de779ed 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -64,10 +64,21 @@ case class SortOrder(child: Expression, direction:
SortDirection)
}
/**
- * An expression to generate a 64-bit long prefix used in sorting.
+ * An expression to generate a 64-bit long prefix used in sorting. If the sort
must operate over
+ * null keys as well, this.nullValue can be used in place of emitted null
prefixes in the sort.
*/
case class SortPrefix(child: SortOrder) extends UnaryExpression {
+ val nullValue = child.child.dataType match {
+ case BooleanType | DateType | TimestampType | _: IntegralType =>
+ Long.MinValue
+ case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS
=>
+ Long.MinValue
+ case _: DecimalType =>
+ DoublePrefixComparator.computePrefix(Double.NegativeInfinity)
+ case _ => 0L
+ }
+
override def eval(input: InternalRow): Any = throw new
UnsupportedOperationException
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -75,20 +86,19 @@ case class SortPrefix(child: SortOrder) extends
UnaryExpression {
val input = childCode.value
val BinaryPrefixCmp = classOf[BinaryPrefixComparator].getName
val DoublePrefixCmp = classOf[DoublePrefixComparator].getName
-
- val (nullValue: Long, prefixCode: String) = child.child.dataType match {
+ val prefixCode = child.child.dataType match {
case BooleanType =>
- (Long.MinValue, s"$input ? 1L : 0L")
+ s"$input ? 1L : 0L"
case _: IntegralType =>
- (Long.MinValue, s"(long) $input")
+ s"(long) $input"
case DateType | TimestampType =>
- (Long.MinValue, s"(long) $input")
+ s"(long) $input"
case FloatType | DoubleType =>
- (0L, s"$DoublePrefixCmp.computePrefix((double)$input)")
- case StringType => (0L, s"$input.getPrefix()")
- case BinaryType => (0L, s"$BinaryPrefixCmp.computePrefix($input)")
+ s"$DoublePrefixCmp.computePrefix((double)$input)"
+ case StringType => s"$input.getPrefix()"
+ case BinaryType => s"$BinaryPrefixCmp.computePrefix($input)"
case dt: DecimalType if dt.precision - dt.scale <=
Decimal.MAX_LONG_DIGITS =>
- val prefix = if (dt.precision <= Decimal.MAX_LONG_DIGITS) {
+ if (dt.precision <= Decimal.MAX_LONG_DIGITS) {
s"$input.toUnscaledLong()"
} else {
// reduce the scale to fit in a long
@@ -96,17 +106,15 @@ case class SortPrefix(child: SortOrder) extends
UnaryExpression {
val s = p - (dt.precision - dt.scale)
s"$input.changePrecision($p, $s) ? $input.toUnscaledLong() :
${Long.MinValue}L"
}
- (Long.MinValue, prefix)
case dt: DecimalType =>
- (DoublePrefixComparator.computePrefix(Double.NegativeInfinity),
- s"$DoublePrefixCmp.computePrefix($input.toDouble())")
- case _ => (0L, "0L")
+ s"$DoublePrefixCmp.computePrefix($input.toDouble())"
+ case _ => "0L"
}
ev.copy(code = childCode.code +
s"""
- |long ${ev.value} = ${nullValue}L;
- |boolean ${ev.isNull} = false;
+ |long ${ev.value} = 0L;
+ |boolean ${ev.isNull} = ${childCode.isNull};
|if (!${childCode.isNull}) {
| ${ev.value} = $prefixCode;
|}
http://git-wip-us.apache.org/repos/asf/spark/blob/beb75300/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
----------------------------------------------------------------------
diff --git
a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index bb823cd..99fe51d 100644
---
a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++
b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -118,9 +118,10 @@ public final class UnsafeKVExternalSorter {
// Compute prefix
row.pointTo(baseObject, baseOffset, loc.getKeyLength());
- final long prefix = prefixComputer.computePrefix(row);
+ final UnsafeExternalRowSorter.PrefixComputer.Prefix prefix =
+ prefixComputer.computePrefix(row);
- inMemSorter.insertRecord(address, prefix);
+ inMemSorter.insertRecord(address, prefix.value, prefix.isNull);
}
sorter = UnsafeExternalSorter.createWithExistingInMemorySorter(
@@ -146,10 +147,12 @@ public final class UnsafeKVExternalSorter {
* sorted runs, and then reallocates memory to hold the new record.
*/
public void insertKV(UnsafeRow key, UnsafeRow value) throws IOException {
- final long prefix = prefixComputer.computePrefix(key);
+ final UnsafeExternalRowSorter.PrefixComputer.Prefix prefix =
+ prefixComputer.computePrefix(key);
sorter.insertKVRecord(
key.getBaseObject(), key.getBaseOffset(), key.getSizeInBytes(),
- value.getBaseObject(), value.getBaseOffset(), value.getSizeInBytes(),
prefix);
+ value.getBaseObject(), value.getBaseOffset(), value.getSizeInBytes(),
+ prefix.value, prefix.isNull);
}
/**
http://git-wip-us.apache.org/repos/asf/spark/blob/beb75300/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
----------------------------------------------------------------------
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
index 66a16ac..6db7f45 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
@@ -68,10 +68,16 @@ case class SortExec(
SortPrefixUtils.canSortFullyWithPrefix(boundSortExpression)
// The generator for prefix
- val prefixProjection =
UnsafeProjection.create(Seq(SortPrefix(boundSortExpression)))
+ val prefixExpr = SortPrefix(boundSortExpression)
+ val prefixProjection = UnsafeProjection.create(Seq(prefixExpr))
val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
- override def computePrefix(row: InternalRow): Long = {
- prefixProjection.apply(row).getLong(0)
+ private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix
+ override def computePrefix(row: InternalRow):
+ UnsafeExternalRowSorter.PrefixComputer.Prefix = {
+ val prefix = prefixProjection.apply(row)
+ result.isNull = prefix.isNullAt(0)
+ result.value = if (result.isNull) prefixExpr.nullValue else
prefix.getLong(0)
+ result
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/beb75300/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
----------------------------------------------------------------------
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
index 1a5ff5f..940467e 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
@@ -33,6 +33,11 @@ object SortPrefixUtils {
override def compare(prefix1: Long, prefix2: Long): Int = 0
}
+ /**
+ * Dummy sort prefix result to use for empty rows.
+ */
+ private val emptyPrefix = new UnsafeExternalRowSorter.PrefixComputer.Prefix
+
def getPrefixComparator(sortOrder: SortOrder): PrefixComparator = {
sortOrder.dataType match {
case StringType =>
@@ -70,10 +75,6 @@ object SortPrefixUtils {
*/
def canSortFullyWithPrefix(sortOrder: SortOrder): Boolean = {
sortOrder.dataType match {
- // TODO(ekl) long-type is problematic because it's null prefix
representation collides with
- // the lowest possible long value. Handle this special case outside
radix sort.
- case LongType if sortOrder.nullable =>
- false
case BooleanType | ByteType | ShortType | IntegerType | LongType |
DateType |
TimestampType | FloatType | DoubleType =>
true
@@ -97,16 +98,29 @@ object SortPrefixUtils {
def createPrefixGenerator(schema: StructType):
UnsafeExternalRowSorter.PrefixComputer = {
if (schema.nonEmpty) {
val boundReference = BoundReference(0, schema.head.dataType, nullable =
true)
- val prefixProjection = UnsafeProjection.create(
- SortPrefix(SortOrder(boundReference, Ascending)))
+ val prefixExpr = SortPrefix(SortOrder(boundReference, Ascending))
+ val prefixProjection = UnsafeProjection.create(prefixExpr)
new UnsafeExternalRowSorter.PrefixComputer {
- override def computePrefix(row: InternalRow): Long = {
- prefixProjection.apply(row).getLong(0)
+ private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix
+ override def computePrefix(row: InternalRow):
+ UnsafeExternalRowSorter.PrefixComputer.Prefix = {
+ val prefix = prefixProjection.apply(row)
+ if (prefix.isNullAt(0)) {
+ result.isNull = true
+ result.value = prefixExpr.nullValue
+ } else {
+ result.isNull = false
+ result.value = prefix.getLong(0)
+ }
+ result
}
}
} else {
new UnsafeExternalRowSorter.PrefixComputer {
- override def computePrefix(row: InternalRow): Long = 0
+ override def computePrefix(row: InternalRow):
+ UnsafeExternalRowSorter.PrefixComputer.Prefix = {
+ emptyPrefix
+ }
}
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/beb75300/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala
----------------------------------------------------------------------
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala
index 97bbab6..1b9634c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala
@@ -347,13 +347,13 @@ case class WindowExec(
SparkEnv.get.memoryManager.pageSizeBytes,
false)
rows.foreach { r =>
- sorter.insertRecord(r.getBaseObject, r.getBaseOffset,
r.getSizeInBytes, 0)
+ sorter.insertRecord(r.getBaseObject, r.getBaseOffset,
r.getSizeInBytes, 0, false)
}
rows.clear()
}
} else {
sorter.insertRecord(nextRow.getBaseObject, nextRow.getBaseOffset,
- nextRow.getSizeInBytes, 0)
+ nextRow.getSizeInBytes, 0, false)
}
fetchNextRow()
}
http://git-wip-us.apache.org/repos/asf/spark/blob/beb75300/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
----------------------------------------------------------------------
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
index 88f78a7..d870d91 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
@@ -53,7 +53,7 @@ class UnsafeCartesianRDD(left : RDD[UnsafeRow], right :
RDD[UnsafeRow], numField
val partition = split.asInstanceOf[CartesianPartition]
for (y <- rdd2.iterator(partition.s2, context)) {
- sorter.insertRecord(y.getBaseObject, y.getBaseOffset, y.getSizeInBytes,
0)
+ sorter.insertRecord(y.getBaseObject, y.getBaseOffset, y.getSizeInBytes,
0, false)
}
// Create an iterator from sorter and wrapper it as Iterator[UnsafeRow]
http://git-wip-us.apache.org/repos/asf/spark/blob/beb75300/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
----------------------------------------------------------------------
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
index c3acf29..ba3fa37 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
@@ -54,6 +54,17 @@ class SortSuite extends SparkPlanTest with SharedSQLContext {
sortAnswers = false)
}
+ test("sorting all nulls") {
+ checkThatPlansAgree(
+ (1 to 100).map(v => Tuple1(v)).toDF().selectExpr("NULL as a"),
+ (child: SparkPlan) =>
+ GlobalLimitExec(10, SortExec('a.asc :: Nil, global = true, child =
child)),
+ (child: SparkPlan) =>
+ GlobalLimitExec(10, ReferenceSort('a.asc :: Nil, global = true,
child)),
+ sortAnswers = false
+ )
+ }
+
test("sort followed by limit") {
checkThatPlansAgree(
(1 to 100).map(v => Tuple1(v)).toDF("a"),
http://git-wip-us.apache.org/repos/asf/spark/blob/beb75300/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala
----------------------------------------------------------------------
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala
index 9964b73..50ae26a 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala
@@ -110,7 +110,7 @@ class SortBenchmark extends BenchmarkBase {
benchmark.addTimerCase("radix sort key prefix array") { timer =>
val (_, buf2) = generateKeyPrefixTestData(size, rand.nextLong)
timer.startTiming()
- RadixSort.sortKeyPrefixArray(buf2, size, 0, 7, false, false)
+ RadixSort.sortKeyPrefixArray(buf2, 0, size, 0, 7, false, false)
timer.stopTiming()
}
benchmark.run()
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]