Repository: spark
Updated Branches:
  refs/heads/master ee15404a2 -> 5be6b0e4f


[SPARK-6195] [SQL] Adds in-memory column type for fixed-precision decimals

This PR adds a specialized in-memory column type for fixed-precision decimals.

For all other column types, a single integer column type ID is enough to 
determine which column type to use. However, this doesn't work for 
fixed-precision decimal types, which also differ by their precision and scale 
parameters. Moreover, with the previous design there seems to be no trivial way 
to encode precision and scale information into the columnar byte buffer. On the 
other hand, we always know the data type of the column to be built / scanned 
ahead of time. So this PR no longer uses the column type ID to construct 
`ColumnBuilder`s and `ColumnAccessor`s, but relies on the actual column data 
type instead. This way, precision and scale information can be passed along.

The column type ID is no longer used and can be removed in a future PR.

### Micro benchmark result

The following micro benchmark builds a simple table with 2 million decimals 
(precision = 10, scale = 0), caches it in memory, and then counts all the rows. 
Code (paste it into the Spark shell):

```scala
import sc._
import sqlContext._
import sqlContext.implicits._
import org.apache.spark.sql.types._
import com.google.common.base.Stopwatch

/**
 * Runs `f` `n` times and prints the wall-clock time of every round, followed by the average.
 *
 * Timing uses the monotonic `System.nanoTime` clock instead of Guava's `Stopwatch`, whose public
 * no-arg constructor and `elapsedMillis()` are deprecated and were removed in later Guava
 * releases, which would make this snippet fail to compile on a newer classpath.
 *
 * @param n number of timed rounds; when `n <= 0` nothing is run and nothing is printed
 * @param f the workload to time; it is re-evaluated on every round and its result is discarded
 */
def benchmark(n: Int)(f: => Long): Unit = {
  // Evaluates the by-name workload once and returns the elapsed time in whole milliseconds.
  def run(): Long = {
    val start = System.nanoTime()
    f
    (System.nanoTime() - start) / 1000000L
  }

  val records = (0 until n).map(_ => run())

  records.zipWithIndex.foreach { case (elapsed, i) => println(s"Round $i: $elapsed ms") }

  // Without this guard an empty run would print "Average: NaN ms".
  if (records.nonEmpty) {
    println(s"Average: ${records.sum / n.toDouble} ms")
  }
}

// Builds a 2,000,000-row table with a single decimal(10, 0) column named "dec", caches it as
// temp table "dec", then times `count()` over it with `benchmark`.
//
// Explicit casting is required because ScalaReflection can't inspect decimal precision
parallelize(1 to 2000000)
  .map(i => Tuple1(Decimal(i, 10, 0)))
  .toDF("dec")
  .select($"dec" cast DecimalType(10, 0))
  .registerTempTable("dec")

sql("CACHE TABLE dec")
val df = table("dec")

// Warm up
df.count()
df.count()

benchmark(5) {
  df.count()
}
```

With `FIXED_DECIMAL` column type:

- Round 0: 75 ms
- Round 1: 97 ms
- Round 2: 75 ms
- Round 3: 70 ms
- Round 4: 72 ms
- Average: 77.8 ms

Without `FIXED_DECIMAL` column type:

- Round 0: 1233 ms
- Round 1: 1170 ms
- Round 2: 1171 ms
- Round 3: 1141 ms
- Round 4: 1141 ms
- Average: 1171.2 ms

<!-- Reviewable:start -->
[<img src="https://reviewable.io/review_button.png" height="40" alt="Review on Reviewable"/>](https://reviewable.io/reviews/apache/spark/4938)
<!-- Reviewable:end -->

Author: Cheng Lian <[email protected]>

Closes #4938 from liancheng/decimal-column-type and squashes the following 
commits:

fef5338 [Cheng Lian] Updates fixed decimal column type related test cases
e08ab5b [Cheng Lian] Only resorts to FIXED_DECIMAL when the value can be held 
in a long
4db713d [Cheng Lian] Adds in-memory column type for fixed-precision decimals


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5be6b0e4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5be6b0e4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5be6b0e4

Branch: refs/heads/master
Commit: 5be6b0e4f48aca12fcd47c1b77c4675ad651c332
Parents: ee15404
Author: Cheng Lian <[email protected]>
Authored: Sat Mar 14 19:53:54 2015 +0800
Committer: Cheng Lian <[email protected]>
Committed: Sat Mar 14 19:53:54 2015 +0800

----------------------------------------------------------------------
 .../spark/sql/columnar/ColumnAccessor.scala     | 43 ++++++++-------
 .../spark/sql/columnar/ColumnBuilder.scala      | 39 ++++++++------
 .../apache/spark/sql/columnar/ColumnStats.scala | 17 ++++++
 .../apache/spark/sql/columnar/ColumnType.scala  | 55 +++++++++++++++-----
 .../columnar/InMemoryColumnarTableScan.scala    |  8 +--
 .../spark/sql/columnar/ColumnStatsSuite.scala   |  1 +
 .../spark/sql/columnar/ColumnTypeSuite.scala    | 46 +++++++++++-----
 .../spark/sql/columnar/ColumnarTestUtils.scala  | 23 ++++----
 .../columnar/InMemoryColumnarQuerySuite.scala   | 17 +++++-
 .../columnar/NullableColumnAccessorSuite.scala  |  3 +-
 .../columnar/NullableColumnBuilderSuite.scala   |  3 +-
 11 files changed, 179 insertions(+), 76 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5be6b0e4/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
index 91c4c10..b615eaa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
@@ -21,7 +21,7 @@ import java.nio.{ByteBuffer, ByteOrder}
 
 import org.apache.spark.sql.catalyst.expressions.MutableRow
 import org.apache.spark.sql.columnar.compression.CompressibleColumnAccessor
-import org.apache.spark.sql.types.{BinaryType, DataType, NativeType}
+import org.apache.spark.sql.types._
 
 /**
  * An `Iterator` like trait used to extract values from columnar byte buffer. 
When a value is
@@ -89,6 +89,9 @@ private[sql] class DoubleColumnAccessor(buffer: ByteBuffer)
 private[sql] class FloatColumnAccessor(buffer: ByteBuffer)
   extends NativeColumnAccessor(buffer, FLOAT)
 
+private[sql] class FixedDecimalColumnAccessor(buffer: ByteBuffer, precision: 
Int, scale: Int)
+  extends NativeColumnAccessor(buffer, FIXED_DECIMAL(precision, scale))
+
 private[sql] class StringColumnAccessor(buffer: ByteBuffer)
   extends NativeColumnAccessor(buffer, STRING)
 
@@ -107,24 +110,28 @@ private[sql] class GenericColumnAccessor(buffer: 
ByteBuffer)
   with NullableColumnAccessor
 
 private[sql] object ColumnAccessor {
-  def apply(buffer: ByteBuffer): ColumnAccessor = {
+  def apply(dataType: DataType, buffer: ByteBuffer): ColumnAccessor = {
     val dup = buffer.duplicate().order(ByteOrder.nativeOrder)
-    // The first 4 bytes in the buffer indicate the column type.
-    val columnTypeId = dup.getInt()
-
-    columnTypeId match {
-      case INT.typeId       => new IntColumnAccessor(dup)
-      case LONG.typeId      => new LongColumnAccessor(dup)
-      case FLOAT.typeId     => new FloatColumnAccessor(dup)
-      case DOUBLE.typeId    => new DoubleColumnAccessor(dup)
-      case BOOLEAN.typeId   => new BooleanColumnAccessor(dup)
-      case BYTE.typeId      => new ByteColumnAccessor(dup)
-      case SHORT.typeId     => new ShortColumnAccessor(dup)
-      case STRING.typeId    => new StringColumnAccessor(dup)
-      case DATE.typeId      => new DateColumnAccessor(dup)
-      case TIMESTAMP.typeId => new TimestampColumnAccessor(dup)
-      case BINARY.typeId    => new BinaryColumnAccessor(dup)
-      case GENERIC.typeId   => new GenericColumnAccessor(dup)
+
+    // The first 4 bytes in the buffer indicate the column type.  This field 
is not used now,
+    // because we always know the data type of the column ahead of time.
+    dup.getInt()
+
+    dataType match {
+      case IntegerType => new IntColumnAccessor(dup)
+      case LongType => new LongColumnAccessor(dup)
+      case FloatType => new FloatColumnAccessor(dup)
+      case DoubleType => new DoubleColumnAccessor(dup)
+      case BooleanType => new BooleanColumnAccessor(dup)
+      case ByteType => new ByteColumnAccessor(dup)
+      case ShortType => new ShortColumnAccessor(dup)
+      case StringType => new StringColumnAccessor(dup)
+      case BinaryType => new BinaryColumnAccessor(dup)
+      case DateType => new DateColumnAccessor(dup)
+      case TimestampType => new TimestampColumnAccessor(dup)
+      case DecimalType.Fixed(precision, scale) if precision < 19 =>
+        new FixedDecimalColumnAccessor(dup, precision, scale)
+      case _ => new GenericColumnAccessor(dup)
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/5be6b0e4/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
index 3a4977b..d8d24a5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
@@ -106,6 +106,13 @@ private[sql] class DoubleColumnBuilder extends 
NativeColumnBuilder(new DoubleCol
 
 private[sql] class FloatColumnBuilder extends NativeColumnBuilder(new 
FloatColumnStats, FLOAT)
 
+private[sql] class FixedDecimalColumnBuilder(
+    precision: Int,
+    scale: Int)
+  extends NativeColumnBuilder(
+    new FixedDecimalColumnStats,
+    FIXED_DECIMAL(precision, scale))
+
 private[sql] class StringColumnBuilder extends NativeColumnBuilder(new 
StringColumnStats, STRING)
 
 private[sql] class DateColumnBuilder extends NativeColumnBuilder(new 
DateColumnStats, DATE)
@@ -139,25 +146,28 @@ private[sql] object ColumnBuilder {
   }
 
   def apply(
-      typeId: Int,
+      dataType: DataType,
       initialSize: Int = 0,
       columnName: String = "",
       useCompression: Boolean = false): ColumnBuilder = {
-
-    val builder = (typeId match {
-      case INT.typeId       => new IntColumnBuilder
-      case LONG.typeId      => new LongColumnBuilder
-      case FLOAT.typeId     => new FloatColumnBuilder
-      case DOUBLE.typeId    => new DoubleColumnBuilder
-      case BOOLEAN.typeId   => new BooleanColumnBuilder
-      case BYTE.typeId      => new ByteColumnBuilder
-      case SHORT.typeId     => new ShortColumnBuilder
-      case STRING.typeId    => new StringColumnBuilder
-      case BINARY.typeId    => new BinaryColumnBuilder
-      case GENERIC.typeId   => new GenericColumnBuilder
-      case DATE.typeId      => new DateColumnBuilder
-      case TIMESTAMP.typeId => new TimestampColumnBuilder
-    }).asInstanceOf[ColumnBuilder]
+    // Keep these cases in sync with ColumnAccessor.apply and ColumnType.apply: a column
+    // written by one type's builder must be read back by the matching accessor.
+    val builder: ColumnBuilder = dataType match {
+      case IntegerType => new IntColumnBuilder
+      case LongType => new LongColumnBuilder
+      case FloatType => new FloatColumnBuilder
+      case DoubleType => new DoubleColumnBuilder
+      case BooleanType => new BooleanColumnBuilder
+      case ByteType => new ByteColumnBuilder
+      case ShortType => new ShortColumnBuilder
+      case StringType => new StringColumnBuilder
+      case BinaryType => new BinaryColumnBuilder
+      case DateType => new DateColumnBuilder
+      case TimestampType => new TimestampColumnBuilder
+      case DecimalType.Fixed(precision, scale) if precision < 19 =>
+        new FixedDecimalColumnBuilder(precision, scale)
+      case _ => new GenericColumnBuilder
+    }
 
     builder.initialize(initialSize, columnName, useCompression)
     builder

http://git-wip-us.apache.org/repos/asf/spark/blob/5be6b0e4/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
index cad0667..04047b9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
@@ -181,6 +181,23 @@ private[sql] class FloatColumnStats extends ColumnStats {
   def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
 }
 
+private[sql] class FixedDecimalColumnStats extends ColumnStats {
+  protected var upper: Decimal = null
+  protected var lower: Decimal = null
+
+  override def gatherStats(row: Row, ordinal: Int): Unit = {
+    super.gatherStats(row, ordinal)
+    if (!row.isNullAt(ordinal)) {
+      val value = row(ordinal).asInstanceOf[Decimal]
+      if (upper == null || value.compareTo(upper) > 0) upper = value
+      if (lower == null || value.compareTo(lower) < 0) lower = value
+      sizeInBytes += FIXED_DECIMAL.defaultSize
+    }
+  }
+
+  override def collectedStatistics: Row = Row(lower, upper, nullCount, count, 
sizeInBytes)
+}
+
 private[sql] class IntColumnStats extends ColumnStats {
   protected var upper = Int.MinValue
   protected var lower = Int.MaxValue

http://git-wip-us.apache.org/repos/asf/spark/blob/5be6b0e4/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index db5bc0d..36ea1c7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -373,6 +373,33 @@ private[sql] object TIMESTAMP extends 
NativeColumnType(TimestampType, 9, 12) {
   }
 }
 
+private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int)
+  extends NativeColumnType(
+    DecimalType(Some(PrecisionInfo(precision, scale))),
+    10,
+    FIXED_DECIMAL.defaultSize) {
+
+  override def extract(buffer: ByteBuffer): Decimal = {
+    Decimal(buffer.getLong(), precision, scale)
+  }
+
+  override def append(v: Decimal, buffer: ByteBuffer): Unit = {
+    buffer.putLong(v.toUnscaledLong)
+  }
+
+  override def getField(row: Row, ordinal: Int): Decimal = {
+    row(ordinal).asInstanceOf[Decimal]
+  }
+
+  override def setField(row: MutableRow, ordinal: Int, value: Decimal): Unit = 
{
+    row(ordinal) = value
+  }
+}
+
+private[sql] object FIXED_DECIMAL {
+  val defaultSize = 8
+}
+
 private[sql] sealed abstract class ByteArrayColumnType[T <: DataType](
     typeId: Int,
     defaultSize: Int)
@@ -394,7 +421,7 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: 
DataType](
   }
 }
 
-private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](10, 
16) {
+private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](11, 
16) {
   override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): 
Unit = {
     row(ordinal) = value
   }
@@ -405,7 +432,7 @@ private[sql] object BINARY extends 
ByteArrayColumnType[BinaryType.type](10, 16)
 // Used to process generic objects (all types other than those listed above). 
Objects should be
 // serialized first before appending to the column `ByteBuffer`, and is also 
extracted as serialized
 // byte array.
-private[sql] object GENERIC extends ByteArrayColumnType[DataType](11, 16) {
+private[sql] object GENERIC extends ByteArrayColumnType[DataType](12, 16) {
   override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): 
Unit = {
     row(ordinal) = SparkSqlSerializer.deserialize[Any](value)
   }
@@ -416,18 +443,20 @@ private[sql] object GENERIC extends 
ByteArrayColumnType[DataType](11, 16) {
 private[sql] object ColumnType {
   def apply(dataType: DataType): ColumnType[_, _] = {
     dataType match {
-      case IntegerType   => INT
-      case LongType      => LONG
-      case FloatType     => FLOAT
-      case DoubleType    => DOUBLE
-      case BooleanType   => BOOLEAN
-      case ByteType      => BYTE
-      case ShortType     => SHORT
-      case StringType    => STRING
-      case BinaryType    => BINARY
-      case DateType      => DATE
+      case IntegerType => INT
+      case LongType => LONG
+      case FloatType => FLOAT
+      case DoubleType => DOUBLE
+      case BooleanType => BOOLEAN
+      case ByteType => BYTE
+      case ShortType => SHORT
+      case StringType => STRING
+      case BinaryType => BINARY
+      case DateType => DATE
       case TimestampType => TIMESTAMP
-      case _             => GENERIC
+      case DecimalType.Fixed(precision, scale) if precision < 19 =>
+        FIXED_DECIMAL(precision, scale)
+      case _ => GENERIC
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/5be6b0e4/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index 8944a32..387faee 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -113,7 +113,7 @@ private[sql] case class InMemoryRelation(
           val columnBuilders = output.map { attribute =>
             val columnType = ColumnType(attribute.dataType)
             val initialBufferSize = columnType.defaultSize * batchSize
-            ColumnBuilder(columnType.typeId, initialBufferSize, 
attribute.name, useCompression)
+            ColumnBuilder(attribute.dataType, initialBufferSize, 
attribute.name, useCompression)
           }.toArray
 
           var rowCount = 0
@@ -274,8 +274,10 @@ private[sql] case class InMemoryColumnarTableScan(
       def cachedBatchesToRows(cacheBatches: Iterator[CachedBatch]) = {
         val rows = cacheBatches.flatMap { cachedBatch =>
           // Build column accessors
-          val columnAccessors = requestedColumnIndices.map { batch =>
-            ColumnAccessor(ByteBuffer.wrap(cachedBatch.buffers(batch)))
+          val columnAccessors = requestedColumnIndices.map { batchColumnIndex 
=>
+            ColumnAccessor(
+              relation.output(batchColumnIndex).dataType,
+              ByteBuffer.wrap(cachedBatch.buffers(batchColumnIndex)))
           }
 
           // Extract rows via column accessors

http://git-wip-us.apache.org/repos/asf/spark/blob/5be6b0e4/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
index 581fccf..fec487f 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
@@ -29,6 +29,7 @@ class ColumnStatsSuite extends FunSuite {
   testColumnStats(classOf[LongColumnStats], LONG, Row(Long.MaxValue, 
Long.MinValue, 0))
   testColumnStats(classOf[FloatColumnStats], FLOAT, Row(Float.MaxValue, 
Float.MinValue, 0))
   testColumnStats(classOf[DoubleColumnStats], DOUBLE, Row(Double.MaxValue, 
Double.MinValue, 0))
+  testColumnStats(classOf[FixedDecimalColumnStats], FIXED_DECIMAL(15, 10), 
Row(null, null, 0))
   testColumnStats(classOf[StringColumnStats], STRING, Row(null, null, 0))
   testColumnStats(classOf[DateColumnStats], DATE, Row(Int.MaxValue, 
Int.MinValue, 0))
   testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, Row(null, null, 0))

http://git-wip-us.apache.org/repos/asf/spark/blob/5be6b0e4/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
index 9ce8459..5f08834 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
@@ -33,8 +33,9 @@ class ColumnTypeSuite extends FunSuite with Logging {
 
   test("defaultSize") {
     val checks = Map(
-      INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4, 
BOOLEAN -> 1,
-      STRING -> 8, DATE -> 4, TIMESTAMP -> 12, BINARY -> 16, GENERIC -> 16)
+      INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4,
+      FIXED_DECIMAL(15, 10) -> 8, BOOLEAN -> 1, STRING -> 8, DATE -> 4, 
TIMESTAMP -> 12,
+      BINARY -> 16, GENERIC -> 16)
 
     checks.foreach { case (columnType, expectedSize) =>
       assertResult(expectedSize, s"Wrong defaultSize for $columnType") {
@@ -56,15 +57,16 @@ class ColumnTypeSuite extends FunSuite with Logging {
       }
     }
 
-    checkActualSize(INT,       Int.MaxValue,      4)
-    checkActualSize(SHORT,     Short.MaxValue,    2)
-    checkActualSize(LONG,      Long.MaxValue,     8)
-    checkActualSize(BYTE,      Byte.MaxValue,     1)
-    checkActualSize(DOUBLE,    Double.MaxValue,   8)
-    checkActualSize(FLOAT,     Float.MaxValue,    4)
-    checkActualSize(BOOLEAN,   true,              1)
-    checkActualSize(STRING,    "hello",           4 + 
"hello".getBytes("utf-8").length)
-    checkActualSize(DATE,      0,                 4)
+    checkActualSize(INT, Int.MaxValue, 4)
+    checkActualSize(SHORT, Short.MaxValue, 2)
+    checkActualSize(LONG, Long.MaxValue, 8)
+    checkActualSize(BYTE, Byte.MaxValue, 1)
+    checkActualSize(DOUBLE, Double.MaxValue, 8)
+    checkActualSize(FLOAT, Float.MaxValue, 4)
+    checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8)
+    checkActualSize(BOOLEAN, true, 1)
+    checkActualSize(STRING, "hello", 4 + "hello".getBytes("utf-8").length)
+    checkActualSize(DATE, 0, 4)
     checkActualSize(TIMESTAMP, new Timestamp(0L), 12)
 
     val binary = Array.fill[Byte](4)(0: Byte)
@@ -93,12 +95,20 @@ class ColumnTypeSuite extends FunSuite with Logging {
 
   testNativeColumnType[DoubleType.type](DOUBLE, _.putDouble(_), _.getDouble)
 
+  testNativeColumnType[DecimalType](
+    FIXED_DECIMAL(15, 10),
+    (buffer: ByteBuffer, decimal: Decimal) => {
+      buffer.putLong(decimal.toUnscaledLong)
+    },
+    (buffer: ByteBuffer) => {
+      Decimal(buffer.getLong(), 15, 10)
+    })
+
   testNativeColumnType[FloatType.type](FLOAT, _.putFloat(_), _.getFloat)
 
   testNativeColumnType[StringType.type](
     STRING,
     (buffer: ByteBuffer, string: String) => {
-
       val bytes = string.getBytes("utf-8")
       buffer.putInt(bytes.length)
       buffer.put(bytes)
@@ -206,4 +216,16 @@ class ColumnTypeSuite extends FunSuite with Logging {
     if (sb.nonEmpty) sb.setLength(sb.length - 1)
     sb.toString()
   }
+
+  test("column type for decimal types with different precision") {
+    (1 to 18).foreach { i =>
+      assertResult(FIXED_DECIMAL(i, 0)) {
+        ColumnType(DecimalType(i, 0))
+      }
+    }
+
+    assertResult(GENERIC) {
+      ColumnType(DecimalType(19, 0))
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/5be6b0e4/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
index 60ed28c..c7a4084 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
@@ -24,7 +24,7 @@ import scala.util.Random
 
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
-import org.apache.spark.sql.types.{DataType, NativeType}
+import org.apache.spark.sql.types.{Decimal, DataType, NativeType}
 
 object ColumnarTestUtils {
   def makeNullRow(length: Int) = {
@@ -41,16 +41,17 @@ object ColumnarTestUtils {
     }
 
     (columnType match {
-      case BYTE      => (Random.nextInt(Byte.MaxValue * 2) - 
Byte.MaxValue).toByte
-      case SHORT     => (Random.nextInt(Short.MaxValue * 2) - 
Short.MaxValue).toShort
-      case INT       => Random.nextInt()
-      case LONG      => Random.nextLong()
-      case FLOAT     => Random.nextFloat()
-      case DOUBLE    => Random.nextDouble()
-      case STRING    => Random.nextString(Random.nextInt(32))
-      case BOOLEAN   => Random.nextBoolean()
-      case BINARY    => randomBytes(Random.nextInt(32))
-      case DATE      => Random.nextInt()
+      case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte
+      case SHORT => (Random.nextInt(Short.MaxValue * 2) - 
Short.MaxValue).toShort
+      case INT => Random.nextInt()
+      case LONG => Random.nextLong()
+      case FLOAT => Random.nextFloat()
+      case DOUBLE => Random.nextDouble()
+      case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, 
precision, scale)
+      case STRING => Random.nextString(Random.nextInt(32))
+      case BOOLEAN => Random.nextBoolean()
+      case BINARY => randomBytes(Random.nextInt(32))
+      case DATE => Random.nextInt()
       case TIMESTAMP =>
         val timestamp = new Timestamp(Random.nextLong())
         timestamp.setNanos(Random.nextInt(999999999))

http://git-wip-us.apache.org/repos/asf/spark/blob/5be6b0e4/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
index 38b0f66..27dfabc 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
@@ -17,11 +17,11 @@
 
 package org.apache.spark.sql.columnar
 
-import org.apache.spark.sql.functions._
 import org.apache.spark.sql.TestData._
 import org.apache.spark.sql.catalyst.expressions.Row
 import org.apache.spark.sql.test.TestSQLContext._
 import org.apache.spark.sql.test.TestSQLContext.implicits._
+import org.apache.spark.sql.types.{DecimalType, Decimal}
 import org.apache.spark.sql.{QueryTest, TestData}
 import org.apache.spark.storage.StorageLevel.MEMORY_ONLY
 
@@ -117,4 +117,19 @@ class InMemoryColumnarQuerySuite extends QueryTest {
     complexData.count()
     complexData.unpersist()
   }
+
+  test("decimal type") {
+    // Casting is required here because ScalaReflection can't capture decimal 
precision information.
+    val df = (1 to 10)
+      .map(i => Tuple1(Decimal(i, 15, 10)))
+      .toDF("dec")
+      .select($"dec" cast DecimalType(15, 10))
+
+    assert(df.schema.head.dataType === DecimalType(15, 10))
+
+    df.cache().registerTempTable("test_fixed_decimal")
+    checkAnswer(
+      sql("SELECT * FROM test_fixed_decimal"),
+      (1 to 10).map(i => Row(Decimal(i, 15, 10).toJavaBigDecimal)))
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/5be6b0e4/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
index f95c895..bb30535 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
@@ -42,7 +42,8 @@ class NullableColumnAccessorSuite extends FunSuite {
   import ColumnarTestUtils._
 
   Seq(
-    INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC, 
DATE, TIMESTAMP
+    INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, FIXED_DECIMAL(15, 
10), BINARY, GENERIC,
+    DATE, TIMESTAMP
   ).foreach {
     testNullableColumnAccessor(_)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/5be6b0e4/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
index 80bd5c9..75a4749 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
@@ -38,7 +38,8 @@ class NullableColumnBuilderSuite extends FunSuite {
   import ColumnarTestUtils._
 
   Seq(
-    INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC, 
DATE, TIMESTAMP
+    INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, FIXED_DECIMAL(15, 
10), BINARY, GENERIC,
+    DATE, TIMESTAMP
   ).foreach {
     testNullableColumnBuilder(_)
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to