Repository: spark Updated Branches: refs/heads/master 1f7b3d9dc -> dd9ae7945
[SPARK-9351] [SQL] remove literals from grouping expressions in Aggregate. Literals in grouping expressions have no effect at all; they only make our grouping key bigger, so we should remove them in the Optimizer. I also make the old and new aggregation code consistent about literals in grouping here. In the old aggregation, literals in grouping are actually already removed, but in the new aggregation they are not, so I explicitly make it a rule in the Optimizer. Author: Wenchen Fan <[email protected]> Closes #7583 from cloud-fan/minor and squashes the following commits: 471adff [Wenchen Fan] add test 0839925 [Wenchen Fan] use transformDown when rewrite final result expressions Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/dd9ae794 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/dd9ae794 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/dd9ae794 Branch: refs/heads/master Commit: dd9ae7945ab65d353ed2b113e0c1a00a0533ffd6 Parents: 1f7b3d9 Author: Wenchen Fan <[email protected]> Authored: Mon Jul 27 11:23:29 2015 -0700 Committer: Michael Armbrust <[email protected]> Committed: Mon Jul 27 11:23:29 2015 -0700 ---------------------------------------------------------------------- .../sql/catalyst/optimizer/Optimizer.scala | 17 +++++- .../spark/sql/catalyst/planning/patterns.scala | 4 +- .../optimizer/AggregateOptimizeSuite.scala | 57 ++++++++++++++++++++ .../ReplaceDistinctWithAggregateSuite.scala | 42 --------------- .../org/apache/spark/sql/SQLQuerySuite.scala | 29 +++++++--- 5 files changed, 97 insertions(+), 52 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/dd9ae794/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index b59f800..813c620 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -36,8 +36,9 @@ object DefaultOptimizer extends Optimizer { // SubQueries are only needed for analysis and can be removed before execution. Batch("Remove SubQueries", FixedPoint(100), EliminateSubQueries) :: - Batch("Distinct", FixedPoint(100), - ReplaceDistinctWithAggregate) :: + Batch("Aggregate", FixedPoint(100), + ReplaceDistinctWithAggregate, + RemoveLiteralFromGroupExpressions) :: Batch("Operator Optimizations", FixedPoint(100), // Operator push down SetOperationPushDown, @@ -799,3 +800,15 @@ object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] { case Distinct(child) => Aggregate(child.output, child.output, child) } } + +/** + * Removes literals from group expressions in [[Aggregate]], as they have no effect to the result + * but only makes the grouping key bigger. 
+ */ +object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case a @ Aggregate(grouping, _, _) => + val newGrouping = grouping.filter(!_.foldable) + a.copy(groupingExpressions = newGrouping) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/dd9ae794/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 1e7b2a5..b9ca712 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -144,14 +144,14 @@ object PartialAggregation { // time. However some of them might be unnamed so we alias them allowing them to be // referenced in the second aggregation. val namedGroupingExpressions: Seq[(Expression, NamedExpression)] = - groupingExpressions.filter(!_.isInstanceOf[Literal]).map { + groupingExpressions.map { case n: NamedExpression => (n, n) case other => (other, Alias(other, "PartialGroup")()) } // Replace aggregations with a new expression that computes the result from the already // computed partial evaluations and grouping values. 
- val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp { + val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformDown { case e: Expression if partialEvaluations.contains(new TreeNodeRef(e)) => partialEvaluations(new TreeNodeRef(e)).finalEvaluation http://git-wip-us.apache.org/repos/asf/spark/blob/dd9ae794/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala new file mode 100644 index 0000000..2d080b9 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +class AggregateOptimizeSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("Aggregate", FixedPoint(100), + ReplaceDistinctWithAggregate, + RemoveLiteralFromGroupExpressions) :: Nil + } + + test("replace distinct with aggregate") { + val input = LocalRelation('a.int, 'b.int) + + val query = Distinct(input) + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = Aggregate(input.output, input.output, input) + + comparePlans(optimized, correctAnswer) + } + + test("remove literals in grouping expression") { + val input = LocalRelation('a.int, 'b.int) + + val query = + input.groupBy('a, Literal(1), Literal(1) + Literal(2))(sum('b)) + val optimized = Optimize.execute(query) + + val correctAnswer = input.groupBy('a)(sum('b)) + + comparePlans(optimized, correctAnswer) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/dd9ae794/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceDistinctWithAggregateSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceDistinctWithAggregateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceDistinctWithAggregateSuite.scala deleted file mode 100644 index df29a62..0000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceDistinctWithAggregateSuite.scala +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation 
(ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.optimizer - -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, LocalRelation, LogicalPlan} -import org.apache.spark.sql.catalyst.rules.RuleExecutor - -class ReplaceDistinctWithAggregateSuite extends PlanTest { - - object Optimize extends RuleExecutor[LogicalPlan] { - val batches = Batch("ProjectCollapsing", Once, ReplaceDistinctWithAggregate) :: Nil - } - - test("replace distinct with aggregate") { - val input = LocalRelation('a.int, 'b.int) - - val query = Distinct(input) - val optimized = Optimize.execute(query.analyze) - - val correctAnswer = Aggregate(input.output, input.output, input) - - comparePlans(optimized, correctAnswer) - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/dd9ae794/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 8cef0b3..358e319 
100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -463,12 +463,29 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("literal in agg grouping expressions") { - checkAnswer( - sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), - Seq(Row(1, 2), Row(2, 2), Row(3, 2))) - checkAnswer( - sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), - Seq(Row(1, 2), Row(2, 2), Row(3, 2))) + def literalInAggTest(): Unit = { + checkAnswer( + sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), + Seq(Row(1, 2), Row(2, 2), Row(3, 2))) + checkAnswer( + sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), + Seq(Row(1, 2), Row(2, 2), Row(3, 2))) + + checkAnswer( + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"), + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) + checkAnswer( + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"), + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) + checkAnswer( + sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"), + sql("SELECT 1, 2, sum(b) FROM testData2")) + } + + literalInAggTest() + withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") { + literalInAggTest() + } } test("aggregates with nulls") { --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
