Repository: spark
Updated Branches:
  refs/heads/master 13c155958 -> 0ae96495d
[SPARK-22223][SQL] ObjectHashAggregate should not introduce unnecessary shuffle

## What changes were proposed in this pull request?

`ObjectHashAggregateExec` should override `outputPartitioning` in order to avoid unnecessary shuffle.

## How was this patch tested?

Added Jenkins test.

Author: Liang-Chi Hsieh <vii...@gmail.com>

Closes #19501 from viirya/SPARK-22223.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0ae96495
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0ae96495
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0ae96495

Branch: refs/heads/master
Commit: 0ae96495dedb54b3b6bae0bd55560820c5ca29a2
Parents: 13c1559
Author: Liang-Chi Hsieh <vii...@gmail.com>
Authored: Mon Oct 16 13:37:58 2017 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Mon Oct 16 13:37:58 2017 +0800

----------------------------------------------------------------------
 .../aggregate/ObjectHashAggregateExec.scala |  2 ++
 .../spark/sql/DataFrameAggregateSuite.scala | 30 ++++++++++++++++++++
 2 files changed, 32 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0ae96495/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
index ec3f9a0..66955b8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
@@ -95,6 +95,8 @@ case class ObjectHashAggregateExec(
     }
   }
 
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
   protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
     val numOutputRows = longMetric("numOutputRows")
     val aggTime = longMetric("aggTime")


http://git-wip-us.apache.org/repos/asf/spark/blob/0ae96495/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 8549eac..06848e4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -21,6 +21,7 @@ import scala.util.Random
 
 import org.apache.spark.sql.execution.WholeStageCodegenExec
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -636,4 +637,33 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
       spark.sql("SELECT 3 AS c, 4 AS d, SUM(b) FROM testData2 GROUP BY c, d"),
       Seq(Row(3, 4, 9)))
   }
+
+  test("SPARK-22223: ObjectHashAggregate should not introduce unnecessary shuffle") {
+    withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") {
+      val df = Seq(("1", "2", 1), ("1", "2", 2), ("2", "3", 3), ("2", "3", 4)).toDF("a", "b", "c")
+        .repartition(col("a"))
+
+      val objHashAggDF = df
+        .withColumn("d", expr("(a, b, c)"))
+        .groupBy("a", "b").agg(collect_list("d").as("e"))
+        .withColumn("f", expr("(b, e)"))
+        .groupBy("a").agg(collect_list("f").as("g"))
+      val aggPlan = objHashAggDF.queryExecution.executedPlan
+
+      val sortAggPlans = aggPlan.collect {
+        case sortAgg: SortAggregateExec => sortAgg
+      }
+      assert(sortAggPlans.isEmpty)
+
+      val objHashAggPlans = aggPlan.collect {
+        case objHashAgg: ObjectHashAggregateExec => objHashAgg
+      }
+      assert(objHashAggPlans.nonEmpty)
+
+      val exchangePlans = aggPlan.collect {
+        case shuffle: ShuffleExchangeExec => shuffle
+      }
+      assert(exchangePlans.length == 1)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org