Repository: spark
Updated Branches:
  refs/heads/master a57aadae8 -> bb1362eb3


[SPARK-10605][SQL] Create native collect_list/collect_set aggregates

## What changes were proposed in this pull request?
We currently use the Hive implementations for the collect_list/collect_set 
aggregate functions. This has two major drawbacks: the use of HiveUDAF (which 
has quite a bit of overhead) and the lack of support for struct datatypes. This 
PR adds a native implementation of these functions to Spark.

The size of the collected list/set may vary. This means we cannot use the fast 
Tungsten aggregation path to perform the aggregation, and that we fall back to 
the slower sort-based path. Another big issue with these operators is that when 
the size of the collected list/set grows too large, we can start experiencing 
large GC pauses and OOMEs.

The `collect*` aggregates implemented in this PR rely on the sort-based 
aggregate path for correctness. They maintain their own internal buffer which 
holds the rows for one group at a time. The sort-based aggregation path is 
triggered by disabling `partialAggregation` for these aggregates (which is 
kinda funny); this technique is also employed in 
`org.apache.spark.sql.hive.HiveUDAFFunction`.
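
To illustrate why a single internal buffer is enough on the sort-based path, here is a minimal, Spark-free sketch. It is not the code from this PR: `CollectSketch`, `CollectListSketch`, `CollectSetSketch` and `SortBasedCollect` are hypothetical names, and the set variant uses `mutable.LinkedHashSet` (the PR uses `mutable.HashSet`) only so that the output order is deterministic. Because the input is sorted by the grouping key, all rows of a group arrive contiguously, so the buffer can be cleared and reused whenever the key changes.

```scala
import scala.collection.mutable

// Hypothetical stand-in for the Collect aggregate: one mutable buffer
// that is reused for every group.
abstract class CollectSketch[A] {
  protected def clearBuffer(): Unit
  protected def append(value: A): Unit
  protected def snapshot: Vector[A]

  def initialize(): Unit = clearBuffer()     // called at the start of each group
  def update(value: A): Unit = append(value) // called once per input row
  def eval(): Vector[A] = snapshot           // called when the group is complete
}

final class CollectListSketch[A] extends CollectSketch[A] {
  private val buffer = mutable.ArrayBuffer.empty[A]
  protected def clearBuffer(): Unit = buffer.clear()
  protected def append(value: A): Unit = { buffer += value; () }
  protected def snapshot: Vector[A] = buffer.toVector
}

final class CollectSetSketch[A] extends CollectSketch[A] {
  // Assumption for this sketch: LinkedHashSet keeps first-seen order.
  private val buffer = mutable.LinkedHashSet.empty[A]
  protected def clearBuffer(): Unit = buffer.clear()
  protected def append(value: A): Unit = { buffer += value; () }
  protected def snapshot: Vector[A] = buffer.toVector
}

object SortBasedCollect {
  // Sort by key (sortBy is stable), then walk the rows once. A key change
  // means the previous group is finished, so emit it and reset the buffer.
  def aggregate[K: Ordering, V](
      rows: Seq[(K, V)],
      agg: CollectSketch[V]): Vector[(K, Vector[V])] = {
    val out = Vector.newBuilder[(K, Vector[V])]
    var current: Option[K] = None
    for ((key, value) <- rows.sortBy(_._1)) {
      if (!current.contains(key)) {
        current.foreach(k => out += (k -> agg.eval()))
        agg.initialize()
        current = Some(key)
      }
      agg.update(value)
    }
    // Emit the last group, if there was any input at all.
    current.foreach(k => out += (k -> agg.eval()))
    out.result()
  }
}
```

With `Seq(2 -> "b", 1 -> "a", 2 -> "b", 1 -> "c")` as input, the list variant yields `Vector((1, Vector(a, c)), (2, Vector(b, b)))` and the set variant yields `Vector((1, Vector(a, c)), (2, Vector(b)))`. There is no `merge` step in the sketch, which mirrors why partial aggregation has to be disabled for this design.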

I have done some performance testing:
```scala
import org.apache.spark.sql.{Dataset, Row}

sql("create function collect_list2 as 'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectList'")

val df = range(0, 10000000).select($"id", (rand(213123L) * 
100000).cast("int").as("grp"))
df.select(countDistinct($"grp")).show

def benchmark(name: String, plan: Dataset[Row], maxItr: Int = 5): Unit = {
   // Do not measure planning.
   plan.queryExecution.executedPlan

   // Execute the plan a number of times and average the result.
   val start = System.nanoTime
   var i = 0
   while (i < maxItr) {
     plan.rdd.foreach(_ => ())
     i += 1
   }
   val time = (System.nanoTime - start) / (maxItr * 1000000L)
   println(s"[$name] $maxItr iterations completed in an average time of $time ms.")
}

val plan1 = df.groupBy($"grp").agg(collect_list($"id"))
val plan2 = df.groupBy($"grp").agg(callUDF("collect_list2", $"id"))

benchmark("Spark collect_list", plan1)
...
> [Spark collect_list] 5 iterations completed in an average time of 3371 ms.

benchmark("Hive collect_list", plan2)
...
> [Hive collect_list] 5 iterations completed in an average time of 9109 ms.
```
Performance is improved by a factor of 2-3 (about 2.7x in the run above: 9109 ms vs. 3371 ms).

## How was this patch tested?
Added tests to `DataFrameAggregateSuite`.

Author: Herman van Hovell <[email protected]>

Closes #12874 from hvanhovell/implode.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bb1362eb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bb1362eb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bb1362eb

Branch: refs/heads/master
Commit: bb1362eb3b36b553dca246b95f59ba7fd8adcc8a
Parents: a57aada
Author: Herman van Hovell <[email protected]>
Authored: Thu May 12 13:56:00 2016 -0700
Committer: Reynold Xin <[email protected]>
Committed: Thu May 12 13:56:00 2016 -0700

----------------------------------------------------------------------
 .../catalyst/analysis/FunctionRegistry.scala    |   2 +
 .../expressions/aggregate/collect.scala         | 119 +++++++++++++++++++
 .../scala/org/apache/spark/sql/functions.scala  |  12 +-
 .../spark/sql/DataFrameAggregateSuite.scala     |  26 ++++
 .../spark/sql/hive/HiveSessionCatalog.scala     |  16 ---
 .../sql/hive/HiveDataFrameAnalyticsSuite.scala  |  11 --
 6 files changed, 149 insertions(+), 37 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/bb1362eb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index ac05dd3..c459fe5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -252,6 +252,8 @@ object FunctionRegistry {
     expression[VarianceSamp]("variance"),
     expression[VariancePop]("var_pop"),
     expression[VarianceSamp]("var_samp"),
+    expression[CollectList]("collect_list"),
+    expression[CollectSet]("collect_set"),
 
     // string functions
     expression[Ascii]("ascii"),

http://git-wip-us.apache.org/repos/asf/spark/blob/bb1362eb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
new file mode 100644
index 0000000..1f4ff9c
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import scala.collection.generic.Growable
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types._
+
+/**
+ * The Collect aggregate function collects all seen expression values into a list of values.
+ *
+ * The operator is bound to the slower sort based aggregation path because the number of
+ * elements (and their memory usage) can not be determined in advance. This also means that the
+ * collected elements are stored on heap, and that too many elements can cause GC pauses and
+ * eventually Out of Memory Errors.
+ */
+abstract class Collect extends ImperativeAggregate {
+
+  val child: Expression
+
+  override def children: Seq[Expression] = child :: Nil
+
+  override def nullable: Boolean = true
+
+  override def dataType: DataType = ArrayType(child.dataType)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+
+  override def supportsPartial: Boolean = false
+
+  override def aggBufferAttributes: Seq[AttributeReference] = Nil
+
+  override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
+
+  override def inputAggBufferAttributes: Seq[AttributeReference] = Nil
+
+  protected[this] val buffer: Growable[Any] with Iterable[Any]
+
+  override def initialize(b: MutableRow): Unit = {
+    buffer.clear()
+  }
+
+  override def update(b: MutableRow, input: InternalRow): Unit = {
+    buffer += child.eval(input)
+  }
+
+  override def merge(buffer: MutableRow, input: InternalRow): Unit = {
+    sys.error("Collect cannot be used in partial aggregations.")
+  }
+
+  override def eval(input: InternalRow): Any = {
+    new GenericArrayData(buffer.toArray)
+  }
+}
+
+/**
+ * Collect a list of elements.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(expr) - Collects and returns a list of non-unique elements.")
+case class CollectList(
+    child: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0) extends Collect {
+
+  def this(child: Expression) = this(child, 0, 0)
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def prettyName: String = "collect_list"
+
+  override protected[this] val buffer: mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty
+}
+
+/**
+ * Collect a list of unique elements.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(expr) - Collects and returns a set of unique elements.")
+case class CollectSet(
+    child: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0) extends Collect {
+
+  def this(child: Expression) = this(child, 0, 0)
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def prettyName: String = "collect_set"
+
+  override protected[this] val buffer: mutable.HashSet[Any] = mutable.HashSet.empty
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/bb1362eb/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 3e295c2..07f5504 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -195,18 +195,14 @@ object functions {
   /**
    * Aggregate function: returns a list of objects with duplicates.
    *
-   * For now this is an alias for the collect_list Hive UDAF.
-   *
    * @group agg_funcs
    * @since 1.6.0
    */
-  def collect_list(e: Column): Column = callUDF("collect_list", e)
+  def collect_list(e: Column): Column = withAggregateFunction { CollectList(e.expr) }
 
   /**
    * Aggregate function: returns a list of objects with duplicates.
    *
-   * For now this is an alias for the collect_list Hive UDAF.
-   *
    * @group agg_funcs
    * @since 1.6.0
    */
@@ -215,18 +211,14 @@ object functions {
   /**
    * Aggregate function: returns a set of objects with duplicate elements eliminated.
    *
-   * For now this is an alias for the collect_set Hive UDAF.
-   *
    * @group agg_funcs
    * @since 1.6.0
    */
-  def collect_set(e: Column): Column = callUDF("collect_set", e)
+  def collect_set(e: Column): Column = withAggregateFunction { CollectSet(e.expr) }
 
   /**
    * Aggregate function: returns a set of objects with duplicate elements eliminated.
    *
-   * For now this is an alias for the collect_set Hive UDAF.
-   *
    * @group agg_funcs
    * @since 1.6.0
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/bb1362eb/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 8a99866..69a9907 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -431,6 +431,32 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
       Row(null, null, null, null, null))
   }
 
+  test("collect functions") {
+    val df = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b")
+    checkAnswer(
+      df.select(collect_list($"a"), collect_list($"b")),
+      Seq(Row(Seq(1, 2, 3), Seq(2, 2, 4)))
+    )
+    checkAnswer(
+      df.select(collect_set($"a"), collect_set($"b")),
+      Seq(Row(Seq(1, 2, 3), Seq(2, 4)))
+    )
+  }
+
+  test("collect functions structs") {
+    val df = Seq((1, 2, 2), (2, 2, 2), (3, 4, 1))
+      .toDF("a", "x", "y")
+      .select($"a", struct($"x", $"y").as("b"))
+    checkAnswer(
+      df.select(collect_list($"a"), sort_array(collect_list($"b"))),
+      Seq(Row(Seq(1, 2, 3), Seq(Row(2, 2), Row(2, 2), Row(4, 1))))
+    )
+    checkAnswer(
+      df.select(collect_set($"a"), sort_array(collect_set($"b"))),
+      Seq(Row(Seq(1, 2, 3), Seq(Row(2, 2), Row(4, 1))))
+    )
+  }
+
   test("SPARK-14664: Decimal sum/avg over window should work.") {
     checkAnswer(
       spark.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"),

http://git-wip-us.apache.org/repos/asf/spark/blob/bb1362eb/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index 75a252c..4f8aac8 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -222,20 +222,4 @@ private[sql] class HiveSessionCatalog(
         }
     }
   }
-
-  // Pre-load a few commonly used Hive built-in functions.
-  HiveSessionCatalog.preloadedHiveBuiltinFunctions.foreach {
-    case (functionName, clazz) =>
-      val builder = makeFunctionBuilder(functionName, clazz)
-      val info = new ExpressionInfo(clazz.getCanonicalName, functionName)
-      createTempFunction(functionName, info, builder, ignoreIfExists = false)
-  }
-}
-
-private[sql] object HiveSessionCatalog {
-  // This is the list of Hive's built-in functions that are commonly used and we want to
-  // pre-load when we create the FunctionRegistry.
-  val preloadedHiveBuiltinFunctions =
-    ("collect_set", classOf[org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectSet]) ::
-    ("collect_list", classOf[org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectList]) :: Nil
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/bb1362eb/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala
index 57f96e7..cc41c04 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala
@@ -58,17 +58,6 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with
     )
   }
 
-  test("collect functions") {
-    checkAnswer(
-      testData.select(collect_list($"a"), collect_list($"b")),
-      Seq(Row(Seq(1, 2, 3), Seq(2, 2, 4)))
-    )
-    checkAnswer(
-      testData.select(collect_set($"a"), collect_set($"b")),
-      Seq(Row(Seq(1, 2, 3), Seq(2, 4)))
-    )
-  }
-
   test("cube") {
     checkAnswer(
       testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")),

