Repository: spark
Updated Branches:
  refs/heads/branch-1.6 baae1ccc9 -> 70d4edda8


[SPARK-11275][SQL] Incorrect results when using rollup/cube

Fixes a bug with grouping sets (including cube/rollup) where aggregates that
included grouping expressions would return the wrong (null) result.

Also simplifies the analyzer rule a bit and leaves column pruning to the 
optimizer.
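
To illustrate the bug being fixed, a minimal, hypothetical repro (assumes a
SQLContext `sqlContext` with implicits imported; not part of this patch):
```
import org.apache.spark.sql.functions._
import sqlContext.implicits._

val sales = Seq(("Java", 2012, 20000.0), ("Java", 2013, 30000.0))
  .toDF("course", "year", "earnings")

// `year` is both a grouping column and an aggregate input. Before this fix,
// max("year") was evaluated on the null-substituted values produced by Expand,
// so the subtotal row ("Java", null) reported a null max instead of 2013.
sales.rollup("course", "year").agg(max("year"), sum("earnings")).show()
```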

Added multiple unit tests to DataFrameAggregateSuite and verified that the
change passes the Hive compatibility suite:
```
build/sbt -Phive -Dspark.hive.whitelist='groupby.*_grouping.*' 'test-only org.apache.spark.sql.hive.execution.HiveCompatibilitySuite'
```

This is an alternative to PR https://github.com/apache/spark/pull/9419, but I
think it's better because it simplifies the analyzer rule instead of adding
another special case to it.

Author: Andrew Ray <ray.and...@gmail.com>

Closes #9815 from aray/groupingset-agg-fix.

(cherry picked from commit 37cff1b1a79cad11277612cb9bc8bc2365cf5ff2)
Signed-off-by: Yin Huai <yh...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/70d4edda
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/70d4edda
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/70d4edda

Branch: refs/heads/branch-1.6
Commit: 70d4edda8f8e426a6286b83234fb685a88265f71
Parents: baae1cc
Author: Andrew Ray <ray.and...@gmail.com>
Authored: Thu Nov 19 15:11:30 2015 -0800
Committer: Yin Huai <yh...@databricks.com>
Committed: Thu Nov 19 15:11:43 2015 -0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  | 58 ++++++++----------
 .../catalyst/plans/logical/basicOperators.scala |  4 ++
 .../spark/sql/DataFrameAggregateSuite.scala     | 62 ++++++++++++++++++++
 3 files changed, 90 insertions(+), 34 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/70d4edda/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 84781cd..47962eb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -213,45 +213,35 @@ class Analyzer(
         GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations)
       case x: GroupingSets =>
        val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()
-        // We will insert another Projection if the GROUP BY keys contains the
-        // non-attribute expressions. And the top operators can references those
-        // expressions by its alias.
-        // e.g. SELECT key%5 as c1 FROM src GROUP BY key%5 ==>
-        //      SELECT a as c1 FROM (SELECT key%5 AS a FROM src) GROUP BY a
-
-        // find all of the non-attribute expressions in the GROUP BY keys
-        val nonAttributeGroupByExpressions = new ArrayBuffer[Alias]()
-
-        // The pair of (the original GROUP BY key, associated attribute)
-        val groupByExprPairs = x.groupByExprs.map(_ match {
-          case e: NamedExpression => (e, e.toAttribute)
-          case other => {
-            val alias = Alias(other, other.toString)()
-            nonAttributeGroupByExpressions += alias // add the non-attributes expression alias
-            (other, alias.toAttribute)
-          }
-        })
-
-        // substitute the non-attribute expressions for aggregations.
-        val aggregation = x.aggregations.map(expr => expr.transformDown {
-          case e => groupByExprPairs.find(_._1.semanticEquals(e)).map(_._2).getOrElse(e)
-        }.asInstanceOf[NamedExpression])
 
-        // substitute the group by expressions.
-        val newGroupByExprs = groupByExprPairs.map(_._2)
+        // Expand works by setting grouping expressions to null as determined by the bitmasks. To
+        // prevent these null values from being used in an aggregate instead of the original value
+        // we need to create new aliases for all group by expressions that will only be used for
+        // the intended purpose.
+        val groupByAliases: Seq[Alias] = x.groupByExprs.map {
+          case e: NamedExpression => Alias(e, e.name)()
+          case other => Alias(other, other.toString)()
+        }
 
-        val child = if (nonAttributeGroupByExpressions.length > 0) {
-          // insert additional projection if contains the
-          // non-attribute expressions in the GROUP BY keys
-          Project(x.child.output ++ nonAttributeGroupByExpressions, x.child)
-        } else {
-          x.child
+        val aggregations: Seq[NamedExpression] = x.aggregations.map {
+          // If an expression is an aggregate (contains an AggregateExpression) then we don't
+          // change it, so that the aggregation is computed on the unmodified value of its
+          // argument expressions.
+          case expr if expr.find(_.isInstanceOf[AggregateExpression]).nonEmpty => expr
+          // If not, then it's a grouping expression and we need to use the modified
+          // (with nulls from Expand) value of the expression.
+          case expr => expr.transformDown {
+            case e => groupByAliases.find(_.child.semanticEquals(e)).map(_.toAttribute).getOrElse(e)
+          }.asInstanceOf[NamedExpression]
         }
 
+        val child = Project(x.child.output ++ groupByAliases, x.child)
+        val groupByAttributes = groupByAliases.map(_.toAttribute)
+
         Aggregate(
-          newGroupByExprs :+ VirtualColumn.groupingIdAttribute,
-          aggregation,
-          Expand(x.bitmasks, newGroupByExprs, gid, child))
+          groupByAttributes :+ VirtualColumn.groupingIdAttribute,
+          aggregations,
+          Expand(x.bitmasks, groupByAttributes, gid, child))
     }
   }
 

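For readers following the new rule above: a minimal standalone sketch, in
plain Scala collections rather than Catalyst (all names illustrative), of why
aliasing the group-by expressions isolates Expand's null substitution from the
aggregates:
```
// Input rows (a, b); we "group by a with rollup" and sum b.
val rows = Seq((1, 10), (1, 20), (2, 30))

// Expand emits one copy of each row per bitmask; a set bit nulls out the
// corresponding *aliased* group key, never the original column.
val bitmasks = Seq(0, 1) // 0 = group by a; 1 = a nulled (grand total)
val expanded = for ((a, b) <- rows; mask <- bitmasks)
  yield (if ((mask & 1) == 0) Some(a) else None, mask, b)

// Aggregate groups on (aliased key, grouping id); sum reads the untouched b.
val aggregated = expanded
  .groupBy { case (aKey, gid, _) => (aKey, gid) }
  .map { case (key, rs) => key -> rs.map(_._3).sum }
println(aggregated)
// contains (Some(1),0) -> 30, (Some(2),0) -> 30, (None,1) -> 60
```
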
http://git-wip-us.apache.org/repos/asf/spark/blob/70d4edda/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 45630a5..0c44448 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -323,6 +323,10 @@ trait GroupingAnalytics extends UnaryNode {
 
   override def output: Seq[Attribute] = aggregations.map(_.toAttribute)
 
+  // Needs to be unresolved before it's translated to Aggregate + Expand because
+  // output attributes will change in analysis.
+  override lazy val resolved: Boolean = false
+
   def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics
 }
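
A side note on the one-line change above: marking the node unresolved is what
forces the analyzer to rewrite it before anything downstream binds to its
output. A simplified, hypothetical illustration (not the real Catalyst
classes):
```
// Hypothetical sketch: the analyzer keeps applying rules until the plan is
// fully resolved, so a node that always reports resolved = false is
// guaranteed to be rewritten before operators bind to its output attributes.
trait Node { def resolved: Boolean }
case class Leaf() extends Node { val resolved = true }
case class GroupingSetsLike() extends Node {
  // Always unresolved: must be translated to Aggregate + Expand first,
  // because the output attributes change during that translation.
  override val resolved: Boolean = false
}
```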
 

http://git-wip-us.apache.org/repos/asf/spark/blob/70d4edda/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 71adf21..9c42f65 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -60,6 +60,68 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
     )
   }
 
+  test("rollup") {
+    checkAnswer(
+      courseSales.rollup("course", "year").sum("earnings"),
+      Row("Java", 2012, 20000.0) ::
+        Row("Java", 2013, 30000.0) ::
+        Row("Java", null, 50000.0) ::
+        Row("dotNET", 2012, 15000.0) ::
+        Row("dotNET", 2013, 48000.0) ::
+        Row("dotNET", null, 63000.0) ::
+        Row(null, null, 113000.0) :: Nil
+    )
+  }
+
+  test("cube") {
+    checkAnswer(
+      courseSales.cube("course", "year").sum("earnings"),
+      Row("Java", 2012, 20000.0) ::
+        Row("Java", 2013, 30000.0) ::
+        Row("Java", null, 50000.0) ::
+        Row("dotNET", 2012, 15000.0) ::
+        Row("dotNET", 2013, 48000.0) ::
+        Row("dotNET", null, 63000.0) ::
+        Row(null, 2012, 35000.0) ::
+        Row(null, 2013, 78000.0) ::
+        Row(null, null, 113000.0) :: Nil
+    )
+  }
+
+  test("rollup overlapping columns") {
+    checkAnswer(
+      testData2.rollup($"a" + $"b" as "foo", $"b" as "bar").agg(sum($"a" - $"b") as "foo"),
+      Row(2, 1, 0) :: Row(3, 2, -1) :: Row(3, 1, 1) :: Row(4, 2, 0) :: Row(4, 1, 2) :: Row(5, 2, 1)
+        :: Row(2, null, 0) :: Row(3, null, 0) :: Row(4, null, 2) :: Row(5, null, 1)
+        :: Row(null, null, 3) :: Nil
+    )
+
+    checkAnswer(
+      testData2.rollup("a", "b").agg(sum("b")),
+      Row(1, 1, 1) :: Row(1, 2, 2) :: Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 1) :: Row(3, 2, 2)
+        :: Row(1, null, 3) :: Row(2, null, 3) :: Row(3, null, 3)
+        :: Row(null, null, 9) :: Nil
+    )
+  }
+
+  test("cube overlapping columns") {
+    checkAnswer(
+      testData2.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")),
+      Row(2, 1, 0) :: Row(3, 2, -1) :: Row(3, 1, 1) :: Row(4, 2, 0) :: Row(4, 1, 2) :: Row(5, 2, 1)
+        :: Row(2, null, 0) :: Row(3, null, 0) :: Row(4, null, 2) :: Row(5, null, 1)
+        :: Row(null, 1, 3) :: Row(null, 2, 0)
+        :: Row(null, null, 3) :: Nil
+    )
+
+    checkAnswer(
+      testData2.cube("a", "b").agg(sum("b")),
+      Row(1, 1, 1) :: Row(1, 2, 2) :: Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 1) :: Row(3, 2, 2)
+        :: Row(1, null, 3) :: Row(2, null, 3) :: Row(3, null, 3)
+        :: Row(null, 1, 3) :: Row(null, 2, 6)
+        :: Row(null, null, 9) :: Nil
+    )
+  }
+
   test("spark.sql.retainGroupColumns config") {
     checkAnswer(
       testData2.groupBy("a").agg(sum($"b")),
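
A quick sanity check on the expected result sizes in the tests above: rollup
over n grouping expressions expands to n + 1 grouping sets, cube to 2^n
(hypothetical helpers, not part of the patch):
```
// Hypothetical helpers: number of grouping sets each operator expands to.
def rollupSetCount(n: Int): Int = n + 1 // prefixes of the column list
def cubeSetCount(n: Int): Int = 1 << n  // every subset of the column list

assert(rollupSetCount(2) == 3) // e.g. (course, year), (course, null), (null, null)
assert(cubeSetCount(2) == 4)   // additionally (null, year)
```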

