Repository: spark
Updated Branches:
refs/heads/branch-2.0 ac6e9a8d9 -> 31ea3c7bd
[SPARK-10605][SQL] Create native collect_list/collect_set aggregates
## What changes were proposed in this pull request?
We currently use the Hive implementations for the collect_list/collect_set
aggregate functions. This has a few major drawbacks: the use of HiveUDAF (which
has quite a bit of overhead) and the lack of support for struct datatypes. This
PR adds native implementation of these functions to Spark.
The size of the collected list/set may vary; this means we cannot use the fast
Tungsten aggregation path to perform the aggregation, and that we fall back to
the slower sort-based path. Another big issue with these operators is that when
the size of the collected list/set grows too large, we can start experiencing
large GC pauses and OOMEs.
The `collect*` aggregates implemented in this PR rely on the sort-based
aggregate path for correctness. They maintain their own internal buffer, which
holds the rows for one group at a time. The sort-based aggregation path is
triggered by disabling partial aggregation (`supportsPartial = false`) for these
aggregates (which is somewhat counterintuitive); this technique is also employed in
`org.apache.spark.sql.hive.HiveUDAFFunction`.
I have done some performance testing:
```scala
import org.apache.spark.sql.{Dataset, Row}
// Register Hive's collect_list UDAF under a different name so both implementations can be compared.
sql("create function collect_list2 as " +
  "'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectList'")
// 10M rows, each assigned to one of (at most) 100,000 integer groups in [0, 100000).
val df = range(0, 10000000).select($"id", (rand(213123L) * 100000).cast("int").as("grp"))
df.select(countDistinct($"grp")).show
/**
 * Executes `plan` `maxItr` times and prints the average wall-clock time per run in
 * milliseconds (integer division). The physical plan is materialized once up front so that
 * planning is excluded from the measurement.
 */
def benchmark(name: String, plan: Dataset[Row], maxItr: Int = 5): Unit = {
  // Do not measure planning: force the physical plan of the plan under test (not a global).
  plan.queryExecution.executedPlan
  // Execute the plan a number of times and average the result.
  val start = System.nanoTime
  var i = 0
  while (i < maxItr) {
    plan.rdd.foreach(_ => ())
    i += 1
  }
  val time = (System.nanoTime - start) / (maxItr * 1000000L)
  println(s"[$name] $maxItr iterations completed in an average time of $time ms.")
}
// plan1 aggregates with the native collect_list; plan2 with the Hive UDAF registered as collect_list2.
val plan1 = df.groupBy($"grp").agg(collect_list($"id"))
val plan2 = df.groupBy($"grp").agg(callUDF("collect_list2", $"id"))
benchmark("Spark collect_list", plan1)
...
> [Spark collect_list] 5 iterations completed in an average time of 3371 ms.
benchmark("Hive collect_list", plan2)
...
> [Hive collect_list] 5 iterations completed in an average time of 9109 ms.
```
Performance is improved by a factor of 2-3.
## How was this patch tested?
Added tests to `DataFrameAggregateSuite`.
Author: Herman van Hovell <[email protected]>
Closes #12874 from hvanhovell/implode.
(cherry picked from commit bb1362eb3b36b553dca246b95f59ba7fd8adcc8a)
Signed-off-by: Reynold Xin <[email protected]>
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/31ea3c7b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/31ea3c7b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/31ea3c7b
Branch: refs/heads/branch-2.0
Commit: 31ea3c7bde94f5bcca1db601f9c16c36c56cef73
Parents: ac6e9a8
Author: Herman van Hovell <[email protected]>
Authored: Thu May 12 13:56:00 2016 -0700
Committer: Reynold Xin <[email protected]>
Committed: Thu May 12 13:56:15 2016 -0700
----------------------------------------------------------------------
.../catalyst/analysis/FunctionRegistry.scala | 2 +
.../expressions/aggregate/collect.scala | 119 +++++++++++++++++++
.../scala/org/apache/spark/sql/functions.scala | 12 +-
.../spark/sql/DataFrameAggregateSuite.scala | 26 ++++
.../spark/sql/hive/HiveSessionCatalog.scala | 16 ---
.../sql/hive/HiveDataFrameAnalyticsSuite.scala | 11 --
6 files changed, 149 insertions(+), 37 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/31ea3c7b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index ac05dd3..c459fe5 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -252,6 +252,8 @@ object FunctionRegistry {
expression[VarianceSamp]("variance"),
expression[VariancePop]("var_pop"),
expression[VarianceSamp]("var_samp"),
+ expression[CollectList]("collect_list"),
+ expression[CollectSet]("collect_set"),
// string functions
expression[Ascii]("ascii"),
http://git-wip-us.apache.org/repos/asf/spark/blob/31ea3c7b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
new file mode 100644
index 0000000..1f4ff9c
--- /dev/null
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import scala.collection.generic.Growable
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * The Collect aggregate function collects all seen expression values into a list of values.
+ *
+ * The operator is bound to the slower sort based aggregation path because the number of
+ * elements (and their memory usage) can not be determined in advance. This also means that the
+ * collected elements are stored on heap, and that too many elements can cause GC pauses and
+ * eventually Out of Memory Errors.
+ *
+ * Null input values are skipped, following the semantics of Hive's collect_list/collect_set
+ * (see org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator), which these
+ * native implementations replace.
+ */
+abstract class Collect extends ImperativeAggregate {
+
+  val child: Expression
+
+  override def children: Seq[Expression] = child :: Nil
+
+  override def nullable: Boolean = true
+
+  override def dataType: DataType = ArrayType(child.dataType)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+
+  // Disabling partial aggregation routes this aggregate onto the sort-based path, where the
+  // internal `buffer` only has to hold the values of one group at a time.
+  override def supportsPartial: Boolean = false
+
+  // No state is kept in the aggregation buffer row; everything lives in `buffer`.
+  override def aggBufferAttributes: Seq[AttributeReference] = Nil
+
+  override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
+
+  override def inputAggBufferAttributes: Seq[AttributeReference] = Nil
+
+  /** Values collected for the current group; cleared by `initialize`. */
+  protected[this] val buffer: Growable[Any] with Iterable[Any]
+
+  override def initialize(b: MutableRow): Unit = {
+    buffer.clear()
+  }
+
+  override def update(b: MutableRow, input: InternalRow): Unit = {
+    val value = child.eval(input)
+    // Do not collect nulls, matching Hive's collect_list/collect_set.
+    if (value != null) {
+      buffer += copyValue(value)
+    }
+  }
+
+  override def merge(buffer: MutableRow, input: InternalRow): Unit = {
+    sys.error("Collect cannot be used in partial aggregations.")
+  }
+
+  override def eval(input: InternalRow): Any = {
+    new GenericArrayData(buffer.toArray)
+  }
+
+  /**
+   * Returns a copy of `value` that does not share memory with the input row. Strings, structs,
+   * arrays and maps returned by `child.eval` may be views into the input row's memory, which
+   * may be reused for later rows; storing them uncopied could corrupt collected values.
+   */
+  private def copyValue(value: Any): Any = value match {
+    case s: UTF8String => s.clone()
+    case r: InternalRow => r.copy()
+    case a: ArrayData => a.copy()
+    case m: MapData => m.copy()
+    case other => other
+  }
+}
+
+/**
+ * Collects the input values of a group into a list, keeping duplicates in arrival order.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(expr) - Collects and returns a list of non-unique elements.")
+case class CollectList(
+    child: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0) extends Collect {
+
+  /** Single-argument constructor used when the function is resolved by name. */
+  def this(child: Expression) = this(child, 0, 0)
+
+  override def prettyName: String = "collect_list"
+
+  // A growable sequence keeps every value, including repeats.
+  override protected[this] val buffer: mutable.ArrayBuffer[Any] = new mutable.ArrayBuffer[Any]
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+}
+
+/**
+ * Collect a set of unique elements.
+ *
+ * Deduplication relies on element equality and hashing. Input containing map types (at any
+ * nesting depth) is rejected during analysis, because map values cannot be reliably compared
+ * for equality here.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(expr) - Collects and returns a set of unique elements.")
+case class CollectSet(
+    child: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0) extends Collect {
+
+  def this(child: Expression) = this(child, 0, 0)
+
+  // Any non-map type is accepted (inputTypes is AnyDataType), so only maps need rejecting.
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (!child.dataType.existsRecursively(_.isInstanceOf[MapType])) {
+      TypeCheckResult.TypeCheckSuccess
+    } else {
+      TypeCheckResult.TypeCheckFailure("collect_set() cannot have map type data")
+    }
+  }
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def prettyName: String = "collect_set"
+
+  override protected[this] val buffer: mutable.HashSet[Any] = mutable.HashSet.empty
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/31ea3c7b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 3e295c2..07f5504 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -195,18 +195,14 @@ object functions {
/**
* Aggregate function: returns a list of objects with duplicates.
*
- * For now this is an alias for the collect_list Hive UDAF.
- *
* @group agg_funcs
* @since 1.6.0
*/
- def collect_list(e: Column): Column = callUDF("collect_list", e)
+  def collect_list(e: Column): Column = withAggregateFunction(CollectList(e.expr))
/**
* Aggregate function: returns a list of objects with duplicates.
*
- * For now this is an alias for the collect_list Hive UDAF.
- *
* @group agg_funcs
* @since 1.6.0
*/
@@ -215,18 +211,14 @@ object functions {
/**
* Aggregate function: returns a set of objects with duplicate elements
eliminated.
*
- * For now this is an alias for the collect_set Hive UDAF.
- *
* @group agg_funcs
* @since 1.6.0
*/
- def collect_set(e: Column): Column = callUDF("collect_set", e)
+  def collect_set(e: Column): Column = withAggregateFunction(CollectSet(e.expr))
/**
* Aggregate function: returns a set of objects with duplicate elements
eliminated.
*
- * For now this is an alias for the collect_set Hive UDAF.
- *
* @group agg_funcs
* @since 1.6.0
*/
http://git-wip-us.apache.org/repos/asf/spark/blob/31ea3c7b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
----------------------------------------------------------------------
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 8a99866..69a9907 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -431,6 +431,32 @@ class DataFrameAggregateSuite extends QueryTest with
SharedSQLContext {
Row(null, null, null, null, null))
}
+  test("collect functions") {
+    val df = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b")
+    checkAnswer(
+      df.select(collect_list($"a"), collect_list($"b")),
+      Seq(Row(Seq(1, 2, 3), Seq(2, 2, 4)))
+    )
+    // collect_set is backed by a hash set, so its element order is unspecified; sort the
+    // results so the assertion does not depend on hash iteration order.
+    checkAnswer(
+      df.select(sort_array(collect_set($"a")), sort_array(collect_set($"b"))),
+      Seq(Row(Seq(1, 2, 3), Seq(2, 4)))
+    )
+  }
+
+  test("collect functions structs") {
+    // Struct column b = (x, y); struct results are sorted so their order is deterministic.
+    val df = Seq((1, 2, 2), (2, 2, 2), (3, 4, 1))
+      .toDF("a", "x", "y")
+      .select($"a", struct($"x", $"y").as("b"))
+    val expectedList = Row(Seq(1, 2, 3), Seq(Row(2, 2), Row(2, 2), Row(4, 1))) :: Nil
+    checkAnswer(df.select(collect_list($"a"), sort_array(collect_list($"b"))), expectedList)
+    val expectedSet = Row(Seq(1, 2, 3), Seq(Row(2, 2), Row(4, 1))) :: Nil
+    checkAnswer(df.select(collect_set($"a"), sort_array(collect_set($"b"))), expectedSet)
+  }
+
test("SPARK-14664: Decimal sum/avg over window should work.") {
checkAnswer(
spark.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"),
http://git-wip-us.apache.org/repos/asf/spark/blob/31ea3c7b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
----------------------------------------------------------------------
diff --git
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index 75a252c..4f8aac8 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -222,20 +222,4 @@ private[sql] class HiveSessionCatalog(
}
}
}
-
- // Pre-load a few commonly used Hive built-in functions.
- HiveSessionCatalog.preloadedHiveBuiltinFunctions.foreach {
- case (functionName, clazz) =>
- val builder = makeFunctionBuilder(functionName, clazz)
- val info = new ExpressionInfo(clazz.getCanonicalName, functionName)
- createTempFunction(functionName, info, builder, ignoreIfExists = false)
- }
-}
-
-private[sql] object HiveSessionCatalog {
- // This is the list of Hive's built-in functions that are commonly used and
we want to
- // pre-load when we create the FunctionRegistry.
- val preloadedHiveBuiltinFunctions =
- ("collect_set",
classOf[org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectSet]) ::
- ("collect_list",
classOf[org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectList]) :: Nil
}
http://git-wip-us.apache.org/repos/asf/spark/blob/31ea3c7b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala
----------------------------------------------------------------------
diff --git
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala
index 57f96e7..cc41c04 100644
---
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala
+++
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala
@@ -58,17 +58,6 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with
TestHiveSingleton with
)
}
- test("collect functions") {
- checkAnswer(
- testData.select(collect_list($"a"), collect_list($"b")),
- Seq(Row(Seq(1, 2, 3), Seq(2, 2, 4)))
- )
- checkAnswer(
- testData.select(collect_set($"a"), collect_set($"b")),
- Seq(Row(Seq(1, 2, 3), Seq(2, 4)))
- )
- }
-
test("cube") {
checkAnswer(
testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")),
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]