Repository: spark
Updated Branches:
  refs/heads/branch-1.5 15d2736af -> 94e6d8f72


[SPARK-10389] [SQL] [1.5] support order by non-attribute grouping expression on Aggregate

Backport of https://github.com/apache/spark/pull/8548 to branch-1.5.

Author: Wenchen Fan <[email protected]>

Closes #9102 from cloud-fan/branch-1.5.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/94e6d8f7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/94e6d8f7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/94e6d8f7

Branch: refs/heads/branch-1.5
Commit: 94e6d8f72bd2752f571d134146022db264b19c3b
Parents: 15d2736
Author: Wenchen Fan <[email protected]>
Authored: Tue Oct 13 16:16:08 2015 -0700
Committer: Michael Armbrust <[email protected]>
Committed: Tue Oct 13 16:16:08 2015 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  | 72 ++++++++++----------
 .../org/apache/spark/sql/SQLQuerySuite.scala    | 10 +++
 2 files changed, 47 insertions(+), 35 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/94e6d8f7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 6e7353f..47db448 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -576,43 +576,47 @@ class Analyzer(
           filter
         }
 
-      case sort @ Sort(sortOrder, global,
-             aggregate @ Aggregate(grouping, originalAggExprs, child))
+      case sort @ Sort(sortOrder, global, aggregate: Aggregate)
         if aggregate.resolved && !sort.resolved =>
 
         // Try resolving the ordering as though it is in the aggregate clause.
         try {
-          val aliasedOrder = sortOrder.map(o => Alias(o.child, "aggOrder")())
-          val aggregatedOrdering = Aggregate(grouping, aliasedOrder, child)
-          val resolvedOperator: Aggregate = 
execute(aggregatedOrdering).asInstanceOf[Aggregate]
-          def resolvedAggregateOrdering = resolvedOperator.aggregateExpressions
-
-          // Expressions that have an aggregate can be pushed down.
-          val needsAggregate = 
resolvedAggregateOrdering.exists(containsAggregate)
-
-          // Attribute references, that are missing from the order but are 
present in the grouping
-          // expressions can also be pushed down.
-          val requiredAttributes = 
resolvedAggregateOrdering.map(_.references).reduce(_ ++ _)
-          val missingAttributes = requiredAttributes -- aggregate.outputSet
-          val validPushdownAttributes =
-            missingAttributes.filter(a => grouping.exists(a.semanticEquals))
-
-          // If resolution was successful and we see the ordering either has 
an aggregate in it or
-          // it is missing something that is projected away by the aggregate, 
add the ordering
-          // the original aggregate operator.
-          if (resolvedOperator.resolved && (needsAggregate || 
validPushdownAttributes.nonEmpty)) {
-            val evaluatedOrderings: Seq[SortOrder] = 
sortOrder.zip(resolvedAggregateOrdering).map {
-              case (order, evaluated) => order.copy(child = 
evaluated.toAttribute)
-            }
-            val aggExprsWithOrdering: Seq[NamedExpression] =
-              resolvedAggregateOrdering ++ originalAggExprs
-
-            Project(aggregate.output,
-              Sort(evaluatedOrderings, global,
-                aggregate.copy(aggregateExpressions = aggExprsWithOrdering)))
-          } else {
-            sort
+          val aliasedOrdering = sortOrder.map(o => Alias(o.child, 
"aggOrder")())
+          val aggregatedOrdering = aggregate.copy(aggregateExpressions = 
aliasedOrdering)
+          val resolvedAggregate: Aggregate = 
execute(aggregatedOrdering).asInstanceOf[Aggregate]
+          val resolvedAliasedOrdering: Seq[Alias] =
+            resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]]
+
+          // If we pass the analysis check, then the ordering expressions 
should only reference to
+          // aggregate expressions or grouping expressions, and it's safe to 
push them down to
+          // Aggregate.
+          checkAnalysis(resolvedAggregate)
+
+          val originalAggExprs = aggregate.aggregateExpressions.map(
+            
CleanupAliases.trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
+
+          // If the ordering expression is same with original aggregate 
expression, we don't need
+          // to push down this ordering expression and can reference the 
original aggregate
+          // expression instead.
+          val needsPushDown = ArrayBuffer.empty[NamedExpression]
+          val evaluatedOrderings = resolvedAliasedOrdering.zip(sortOrder).map {
+            case (evaluated, order) =>
+              val index = originalAggExprs.indexWhere {
+                case Alias(child, _) => child semanticEquals evaluated.child
+                case other => other semanticEquals evaluated.child
+              }
+
+              if (index == -1) {
+                needsPushDown += evaluated
+                order.copy(child = evaluated.toAttribute)
+              } else {
+                order.copy(child = originalAggExprs(index).toAttribute)
+              }
           }
+
+          Project(aggregate.output,
+            Sort(evaluatedOrderings, global,
+              aggregate.copy(aggregateExpressions = originalAggExprs ++ 
needsPushDown)))
         } catch {
           // Attempting to resolve in the aggregate can result in ambiguity.  
When this happens,
           // just return the original plan.
@@ -621,9 +625,7 @@ class Analyzer(
     }
 
     protected def containsAggregate(condition: Expression): Boolean = {
-      condition
-        .collect { case ae: AggregateExpression => ae }
-        .nonEmpty
+      condition.find(_.isInstanceOf[AggregateExpression]).isDefined
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/94e6d8f7/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 4f31bd0..598a6ea 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1745,4 +1745,14 @@ class SQLQuerySuite extends QueryTest with 
SharedSQLContext {
         df1.withColumn("diff", lit(0)))
     }
   }
+
+  test("SPARK-10389: order by non-attribute grouping expression on Aggregate") 
{
+    withTempTable("src") {
+      Seq((1, 1), (-1, 1)).toDF("key", "value").registerTempTable("src")
+      checkAnswer(sql("SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY 
key + 1"),
+        Seq(Row(1), Row(1)))
+      checkAnswer(sql("SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY 
(key + 1) * 2"),
+        Seq(Row(1), Row(1)))
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to