Repository: spark
Updated Branches:
  refs/heads/master d2383fb5f -> 6d4e854ff


[SPARK-6367][SQL] Use the proper data type for those expressions that are hijacking existing data types.

This PR adds internal UDTs for expressions that are hijacking existing data types.
The following UDTs are added (a sketch of the common pattern follows the list):
* `HyperLogLogUDT` (`BinaryType` as the SQL type) for `ApproxCountDistinctPartition`
* `OpenHashSetUDT` (`ArrayType` as the SQL type) for `CollectHashSet`, `NewSet`, `AddItemToSet`, and `CombineSets`.
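
The pattern in both cases is the same: the expression keeps producing its internal object, but its `dataType` now reports a `UserDefinedType` whose `sqlType` is the type that was previously being hijacked. A minimal sketch of that shape, using a made-up `Sketch` class rather than the real `HyperLogLog`/`OpenHashSet` (the names here are illustrative only, not part of this patch):

```scala
import org.apache.spark.sql.types._

// Hypothetical internal result object, standing in for HyperLogLog or OpenHashSet.
class Sketch(val bytes: Array[Byte])

// The UDT declares the SQL-facing representation (BinaryType here) and how to
// convert to and from it, so the expression no longer claims BinaryType itself.
case object SketchUDT extends UserDefinedType[Sketch] {
  override def sqlType: DataType = BinaryType

  // Rarely hit for internal-only expressions, but must still round-trip correctly.
  override def serialize(obj: Any): Array[Byte] = obj.asInstanceOf[Sketch].bytes
  override def deserialize(datum: Any): Sketch = new Sketch(datum.asInstanceOf[Array[Byte]])

  override def userClass: Class[Sketch] = classOf[Sketch]
}

// The aggregate's dataType then becomes SketchUDT instead of BinaryType.
```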

I am also adding more unit tests for aggregation with code gen enabled.
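
For reference, those tests boil down to toggling the codegen flag around an aggregation query. A rough standalone sketch of that flavor of test (the table name and assertion are illustrative, not the exact test bodies added in `SQLQuerySuite` below):

```scala
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.{sparkContext, sql}
import org.apache.spark.sql.test.TestSQLContext.implicits._

// Illustrative table name; 100 distinct keys, so count(distinct key) should be 100.
sparkContext.parallelize(1 to 100).map(i => (i, i.toString)).toDF("key", "value")
  .registerTempTable("udtAggData")

// Force the code-generated aggregation path; distinct counts now carry
// OpenHashSetUDT-typed buffers through GeneratedAggregate.
val original = TestSQLContext.getConf("spark.sql.codegen", "false")
TestSQLContext.setConf("spark.sql.codegen", "true")
try {
  assert(sql("SELECT count(distinct key) FROM udtAggData").collect().head.getLong(0) == 100L)
} finally {
  TestSQLContext.setConf("spark.sql.codegen", original)
}
```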

JIRA: https://issues.apache.org/jira/browse/SPARK-6367

Author: Yin Huai <[email protected]>

Closes #5094 from yhuai/expressionType and squashes the following commits:

8bcd11a [Yin Huai] Return types.
61a1d66 [Yin Huai] Merge remote-tracking branch 'upstream/master' into expressionType
e8b4599 [Yin Huai] Merge remote-tracking branch 'upstream/master' into expressionType
2753156 [Yin Huai] Ignore aggregations having sum functions for now.
b5eb259 [Yin Huai] Case object for HyperLogLog type.
00ebdbd [Yin Huai] deserialize/serialize.
54b87ae [Yin Huai] Add UDTs for expressions that return HyperLogLog and OpenHashSet.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6d4e854f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6d4e854f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6d4e854f

Branch: refs/heads/master
Commit: 6d4e854ffbd7dee9a3cd7b44a00fd9c0e551f5b8
Parents: d2383fb
Author: Yin Huai <[email protected]>
Authored: Sat Apr 11 19:26:15 2015 -0700
Committer: Michael Armbrust <[email protected]>
Committed: Sat Apr 11 19:26:15 2015 -0700

----------------------------------------------------------------------
 .../sql/catalyst/expressions/aggregates.scala   | 24 ++++++++++++--
 .../expressions/codegen/CodeGenerator.scala     |  4 +--
 .../spark/sql/catalyst/expressions/sets.scala   | 35 +++++++++++++++++---
 .../sql/execution/GeneratedAggregate.scala      | 12 ++++---
 .../org/apache/spark/sql/SQLQuerySuite.scala    | 12 ++++---
 .../apache/spark/sql/UserDefinedTypeSuite.scala | 24 +++++++++++++-
 6 files changed, 91 insertions(+), 20 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6d4e854f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 406de38..14a8550 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -189,9 +189,10 @@ case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpress
 
   override def children: Seq[Expression] = expressions
   override def nullable: Boolean = false
-  override def dataType: ArrayType = ArrayType(expressions.head.dataType)
+  override def dataType: OpenHashSetUDT = new OpenHashSetUDT(expressions.head.dataType)
   override def toString: String = s"AddToHashSet(${expressions.mkString(",")})"
-  override def newInstance(): CollectHashSetFunction = new CollectHashSetFunction(expressions, this)
+  override def newInstance(): CollectHashSetFunction =
+    new CollectHashSetFunction(expressions, this)
 }
 
 case class CollectHashSetFunction(
@@ -250,11 +251,28 @@ case class CombineSetsAndCountFunction(
   override def eval(input: Row): Any = seen.size.toLong
 }
 
+/** The data type of ApproxCountDistinctPartition since its output is a HyperLogLog object. */
+private[sql] case object HyperLogLogUDT extends UserDefinedType[HyperLogLog] {
+
+  override def sqlType: DataType = BinaryType
+
+  /** Since we are using HyperLogLog internally, usually it will not be called. */
+  override def serialize(obj: Any): Array[Byte] =
+    obj.asInstanceOf[HyperLogLog].getBytes
+
+
+  /** Since we are using HyperLogLog internally, usually it will not be called. */
+  override def deserialize(datum: Any): HyperLogLog =
+    HyperLogLog.Builder.build(datum.asInstanceOf[Array[Byte]])
+
+  override def userClass: Class[HyperLogLog] = classOf[HyperLogLog]
+}
+
 case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
   extends AggregateExpression with trees.UnaryNode[Expression] {
 
   override def nullable: Boolean = false
-  override def dataType: DataType = child.dataType
+  override def dataType: DataType = HyperLogLogUDT
   override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)"
   override def newInstance(): ApproxCountDistinctPartitionFunction = {
     new ApproxCountDistinctPartitionFunction(child, this, relativeSD)

http://git-wip-us.apache.org/repos/asf/spark/blob/6d4e854f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index d1abf3c..aac56e1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -464,7 +464,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
         val itemEval = expressionEvaluator(item)
         val setEval = expressionEvaluator(set)
 
-        val ArrayType(elementType, _) = set.dataType
+        val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType
 
         itemEval.code ++ setEval.code ++
         q"""
@@ -482,7 +482,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
         val leftEval = expressionEvaluator(left)
         val rightEval = expressionEvaluator(right)
 
-        val ArrayType(elementType, _) = left.dataType
+        val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType
 
         leftEval.code ++ rightEval.code ++
         q"""

http://git-wip-us.apache.org/repos/asf/spark/blob/6d4e854f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
index 35faa00..4c44182 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
@@ -20,6 +20,33 @@ package org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.types._
 import org.apache.spark.util.collection.OpenHashSet
 
+/** The data type for expressions returning an OpenHashSet as the result. */
+private[sql] class OpenHashSetUDT(
+    val elementType: DataType) extends UserDefinedType[OpenHashSet[Any]] {
+
+  override def sqlType: DataType = ArrayType(elementType)
+
+  /** Since we are using OpenHashSet internally, usually it will not be called. */
+  override def serialize(obj: Any): Seq[Any] = {
+    obj.asInstanceOf[OpenHashSet[Any]].iterator.toSeq
+  }
+
+  /** Since we are using OpenHashSet internally, usually it will not be called. */
+  override def deserialize(datum: Any): OpenHashSet[Any] = {
+    val iterator = datum.asInstanceOf[Seq[Any]].iterator
+    val set = new OpenHashSet[Any]
+    while(iterator.hasNext) {
+      set.add(iterator.next())
+    }
+
+    set
+  }
+
+  override def userClass: Class[OpenHashSet[Any]] = classOf[OpenHashSet[Any]]
+
+  private[spark] override def asNullable: OpenHashSetUDT = this
+}
+
 /**
  * Creates a new set of the specified type
  */
@@ -28,9 +55,7 @@ case class NewSet(elementType: DataType) extends LeafExpression {
 
   override def nullable: Boolean = false
 
-  // We are currently only using these Expressions internally for aggregation.  However, if we ever
-  // expose these to users we'll want to create a proper type instead of hijacking ArrayType.
-  override def dataType: DataType = ArrayType(elementType)
+  override def dataType: OpenHashSetUDT = new OpenHashSetUDT(elementType)
 
   override def eval(input: Row): Any = {
     new OpenHashSet[Any]()
@@ -50,7 +75,7 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {
 
   override def nullable: Boolean = set.nullable
 
-  override def dataType: DataType = set.dataType
+  override def dataType: OpenHashSetUDT = set.dataType.asInstanceOf[OpenHashSetUDT]
 
   override def eval(input: Row): Any = {
     val itemEval = item.eval(input)
@@ -80,7 +105,7 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres
 
   override def nullable: Boolean = left.nullable || right.nullable
 
-  override def dataType: DataType = left.dataType
+  override def dataType: OpenHashSetUDT = left.dataType.asInstanceOf[OpenHashSetUDT]
 
   override def symbol: String = "++="
 

http://git-wip-us.apache.org/repos/asf/spark/blob/6d4e854f/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index 861a2c2..3c58e93 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -68,6 +68,8 @@ case class GeneratedAggregate(
       a.collect { case agg: AggregateExpression => agg}
     }
 
+    // If you add any new function support, please add tests in org.apache.spark.sql.SQLQuerySuite
+    // (in test "aggregation with codegen").
     val computeFunctions = aggregatesToCompute.map {
       case c @ Count(expr) =>
        // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its
@@ -208,7 +210,8 @@ case class GeneratedAggregate(
           currentMax)
 
       case CollectHashSet(Seq(expr)) =>
-        val set = AttributeReference("hashSet", ArrayType(expr.dataType), nullable = false)()
+        val set =
+          AttributeReference("hashSet", new OpenHashSetUDT(expr.dataType), nullable = false)()
         val initialValue = NewSet(expr.dataType)
         val addToSet = AddItemToSet(expr, set)
 
@@ -219,9 +222,10 @@ case class GeneratedAggregate(
           set)
 
       case CombineSetsAndCount(inputSet) =>
-        val ArrayType(inputType, _) = inputSet.dataType
-        val set = AttributeReference("hashSet", inputSet.dataType, nullable = false)()
-        val initialValue = NewSet(inputType)
+        val elementType = inputSet.dataType.asInstanceOf[OpenHashSetUDT].elementType
+        val set =
+          AttributeReference("hashSet", new OpenHashSetUDT(elementType), nullable = false)()
+        val initialValue = NewSet(elementType)
         val collectSets = CombineSets(set, inputSet)
 
         AggregateEvaluation(

http://git-wip-us.apache.org/repos/asf/spark/blob/6d4e854f/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index fb8fc6d..5e453e0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql
 
+import org.apache.spark.sql.execution.GeneratedAggregate
 import org.apache.spark.sql.test.TestSQLContext
 import org.scalatest.BeforeAndAfterAll
 
@@ -151,10 +152,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
       "SELECT count(distinct key) FROM testData3x",
       Row(100) :: Nil)
     // SUM
-     testCodeGen(
-       "SELECT value, sum(key) FROM testData3x GROUP BY value",
-       (1 to 100).map(i => Row(i.toString, 3 * i)))
-     testCodeGen(
+    testCodeGen(
+      "SELECT value, sum(key) FROM testData3x GROUP BY value",
+      (1 to 100).map(i => Row(i.toString, 3 * i)))
+    testCodeGen(
       "SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x",      
       Row(5050 * 3, 5050 * 3.0) :: Nil)
     // AVERAGE
@@ -192,10 +193,11 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
     testCodeGen(
       "SELECT  sum('a'), avg('a'), count(null) FROM testData",
       Row(0, null, 0) :: Nil)
-      
+
     dropTempTable("testData3x")
     setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
   }
+
   test("Add Parser of SQL COALESCE()") {
     checkAnswer(
       sql("""SELECT COALESCE(1, 2)"""),

http://git-wip-us.apache.org/repos/asf/spark/blob/6d4e854f/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 902da5c..2672e20 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -23,13 +23,16 @@ import org.apache.spark.util.Utils
 
 import scala.beans.{BeanInfo, BeanProperty}
 
+import com.clearspring.analytics.stream.cardinality.HyperLogLog
+
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.test.TestSQLContext.{sparkContext, sql}
 import org.apache.spark.sql.test.TestSQLContext.implicits._
 import org.apache.spark.sql.types._
-
+import org.apache.spark.util.collection.OpenHashSet
 
 @SQLUserDefinedType(udt = classOf[MyDenseVectorUDT])
 private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable {
@@ -119,4 +122,23 @@ class UserDefinedTypeSuite extends QueryTest {
     df.limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0)
     df.orderBy('int).limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0)
   }
+
+  test("HyperLogLogUDT") {
+    val hyperLogLogUDT = HyperLogLogUDT
+    val hyperLogLog = new HyperLogLog(0.4)
+    (1 to 10).foreach(i => hyperLogLog.offer(Row(i)))
+
+    val actual = hyperLogLogUDT.deserialize(hyperLogLogUDT.serialize(hyperLogLog))
+    assert(actual.cardinality() === hyperLogLog.cardinality())
+    assert(java.util.Arrays.equals(actual.getBytes, hyperLogLog.getBytes))
+  }
+
+  test("OpenHashSetUDT") {
+    val openHashSetUDT = new OpenHashSetUDT(IntegerType)
+    val set = new OpenHashSet[Int]
+    (1 to 10).foreach(i => set.add(i))
+
+    val actual = openHashSetUDT.deserialize(openHashSetUDT.serialize(set))
+    assert(actual.iterator.toSet === set.iterator.toSet)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
