Repository: spark
Updated Branches:
  refs/heads/branch-1.2 8608ff598 -> 1d7ee2b79


[SPARK-4318][SQL] Fix empty sum distinct.

Executing SUM(DISTINCT ...) on an empty table throws
`java.lang.UnsupportedOperationException: empty.reduceLeft`.

Author: Takuya UESHIN <[email protected]>

Closes #3184 from ueshin/issues/SPARK-4318 and squashes the following commits:

8168c42 [Takuya UESHIN] Merge branch 'master' into issues/SPARK-4318
66fdb0a [Takuya UESHIN] Re-refine aggregate functions.
6186eb4 [Takuya UESHIN] Fix Sum of GeneratedAggregate.
d2975f6 [Takuya UESHIN] Refine Sum and Average of GeneratedAggregate.
1bba675 [Takuya UESHIN] Refine Sum, SumDistinct and Average functions.
917e533 [Takuya UESHIN] Use aggregate instead of groupBy().
1a5f874 [Takuya UESHIN] Add tests to be executed as non-partial aggregation.
a5a57d2 [Takuya UESHIN] Fix empty Average.
22799dc [Takuya UESHIN] Fix empty Sum and SumDistinct.
65b7dd2 [Takuya UESHIN] Fix empty sum distinct.

(cherry picked from commit 2c2e7a44db2ebe44121226f3eac924a0668b991a)
Signed-off-by: Michael Armbrust <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1d7ee2b7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1d7ee2b7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1d7ee2b7

Branch: refs/heads/branch-1.2
Commit: 1d7ee2b79b23f08f73a6d53f41ac8fa140b91c19
Parents: 8608ff5
Author: Takuya UESHIN <[email protected]>
Authored: Thu Nov 20 15:41:24 2014 -0800
Committer: Michael Armbrust <[email protected]>
Committed: Thu Nov 20 15:41:36 2014 -0800

----------------------------------------------------------------------
 .../sql/catalyst/expressions/aggregates.scala   | 103 ++++++++++++++-----
 .../sql/execution/GeneratedAggregate.scala      |  68 +++++++-----
 .../org/apache/spark/sql/DslQuerySuite.scala    |  65 +++++++++++-
 .../scala/org/apache/spark/sql/TestData.scala   |  11 ++
 4 files changed, 195 insertions(+), 52 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1d7ee2b7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 3ceb5ec..0cd9086 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -158,7 +158,7 @@ case class Count(child: Expression) extends 
PartialAggregate with trees.UnaryNod
 
   override def asPartial: SplitEvaluation = {
     val partialCount = Alias(Count(child), "PartialCount")()
-    SplitEvaluation(Sum(partialCount.toAttribute), partialCount :: Nil)
+    SplitEvaluation(Coalesce(Seq(Sum(partialCount.toAttribute), Literal(0L))), 
partialCount :: Nil)
   }
 
   override def newInstance() = new CountFunction(child, this)
@@ -285,7 +285,7 @@ case class ApproxCountDistinct(child: Expression, 
relativeSD: Double = 0.05)
 
 case class Average(child: Expression) extends PartialAggregate with 
trees.UnaryNode[Expression] {
 
-  override def nullable = false
+  override def nullable = true
 
   override def dataType = child.dataType match {
     case DecimalType.Fixed(precision, scale) =>
@@ -299,12 +299,12 @@ case class Average(child: Expression) extends 
PartialAggregate with trees.UnaryN
   override def toString = s"AVG($child)"
 
   override def asPartial: SplitEvaluation = {
-    val partialSum = Alias(Sum(child), "PartialSum")()
-    val partialCount = Alias(Count(child), "PartialCount")()
-
     child.dataType match {
       case DecimalType.Fixed(_, _) =>
-        // Turn the results to unlimited decimals for the division, before 
going back to fixed
+        // Turn the child to unlimited decimals for calculation, before going 
back to fixed
+        val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), 
"PartialSum")()
+        val partialCount = Alias(Count(child), "PartialCount")()
+
         val castedSum = Cast(Sum(partialSum.toAttribute), 
DecimalType.Unlimited)
         val castedCount = Cast(Sum(partialCount.toAttribute), 
DecimalType.Unlimited)
         SplitEvaluation(
@@ -312,6 +312,9 @@ case class Average(child: Expression) extends 
PartialAggregate with trees.UnaryN
           partialCount :: partialSum :: Nil)
 
       case _ =>
+        val partialSum = Alias(Sum(child), "PartialSum")()
+        val partialCount = Alias(Count(child), "PartialCount")()
+
         val castedSum = Cast(Sum(partialSum.toAttribute), dataType)
         val castedCount = Cast(Sum(partialCount.toAttribute), dataType)
         SplitEvaluation(
@@ -325,7 +328,7 @@ case class Average(child: Expression) extends 
PartialAggregate with trees.UnaryN
 
 case class Sum(child: Expression) extends PartialAggregate with 
trees.UnaryNode[Expression] {
 
-  override def nullable = false
+  override def nullable = true
 
   override def dataType = child.dataType match {
     case DecimalType.Fixed(precision, scale) =>
@@ -339,10 +342,19 @@ case class Sum(child: Expression) extends 
PartialAggregate with trees.UnaryNode[
   override def toString = s"SUM($child)"
 
   override def asPartial: SplitEvaluation = {
-    val partialSum = Alias(Sum(child), "PartialSum")()
-    SplitEvaluation(
-      Sum(partialSum.toAttribute),
-      partialSum :: Nil)
+    child.dataType match {
+      case DecimalType.Fixed(_, _) =>
+        val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), 
"PartialSum")()
+        SplitEvaluation(
+          Cast(Sum(partialSum.toAttribute), dataType),
+          partialSum :: Nil)
+
+      case _ =>
+        val partialSum = Alias(Sum(child), "PartialSum")()
+        SplitEvaluation(
+          Sum(partialSum.toAttribute),
+          partialSum :: Nil)
+    }
   }
 
   override def newInstance() = new SumFunction(child, this)
@@ -351,7 +363,7 @@ case class Sum(child: Expression) extends PartialAggregate 
with trees.UnaryNode[
 case class SumDistinct(child: Expression)
   extends AggregateExpression with trees.UnaryNode[Expression] {
 
-  override def nullable = false
+  override def nullable = true
 
   override def dataType = child.dataType match {
     case DecimalType.Fixed(precision, scale) =>
@@ -401,16 +413,37 @@ case class AverageFunction(expr: Expression, base: 
AggregateExpression)
 
   def this() = this(null, null) // Required for serialization.
 
-  private val zero = Cast(Literal(0), expr.dataType)
+  private val calcType =
+    expr.dataType match {
+      case DecimalType.Fixed(_, _) =>
+        DecimalType.Unlimited
+      case _ =>
+        expr.dataType
+    }
+
+  private val zero = Cast(Literal(0), calcType)
 
   private var count: Long = _
-  private val sum = MutableLiteral(zero.eval(null), expr.dataType)
-  private val sumAsDouble = Cast(sum, DoubleType)
+  private val sum = MutableLiteral(zero.eval(null), calcType)
 
-  private def addFunction(value: Any) = Add(sum, Literal(value))
+  private def addFunction(value: Any) = Add(sum, Cast(Literal(value, 
expr.dataType), calcType))
 
-  override def eval(input: Row): Any =
-    sumAsDouble.eval(EmptyRow).asInstanceOf[Double] / count.toDouble
+  override def eval(input: Row): Any = {
+    if (count == 0L) {
+      null
+    } else {
+      expr.dataType match {
+        case DecimalType.Fixed(_, _) =>
+          Cast(Divide(
+            Cast(sum, DecimalType.Unlimited),
+            Cast(Literal(count), DecimalType.Unlimited)), dataType).eval(null)
+        case _ =>
+          Divide(
+            Cast(sum, dataType),
+            Cast(Literal(count), dataType)).eval(null)
+      }
+    }
+  }
 
   override def update(input: Row): Unit = {
     val evaluatedExpr = expr.eval(input)
@@ -475,17 +508,31 @@ case class ApproxCountDistinctMergeFunction(
 case class SumFunction(expr: Expression, base: AggregateExpression) extends 
AggregateFunction {
   def this() = this(null, null) // Required for serialization.
 
-  private val zero = Cast(Literal(0), expr.dataType)
+  private val calcType =
+    expr.dataType match {
+      case DecimalType.Fixed(_, _) =>
+        DecimalType.Unlimited
+      case _ =>
+        expr.dataType
+    }
+
+  private val zero = Cast(Literal(0), calcType)
 
-  private val sum = MutableLiteral(zero.eval(null), expr.dataType)
+  private val sum = MutableLiteral(null, calcType)
 
-  private val addFunction = Add(sum, Coalesce(Seq(expr, zero)))
+  private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), 
Cast(expr, calcType)), sum))
 
   override def update(input: Row): Unit = {
     sum.update(addFunction, input)
   }
 
-  override def eval(input: Row): Any = sum.eval(null)
+  override def eval(input: Row): Any = {
+    expr.dataType match {
+      case DecimalType.Fixed(_, _) =>
+        Cast(sum, dataType).eval(null)
+      case _ => sum.eval(null)
+    }
+  }
 }
 
 case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
@@ -502,8 +549,16 @@ case class SumDistinctFunction(expr: Expression, base: 
AggregateExpression)
     }
   }
 
-  override def eval(input: Row): Any =
-    
seen.reduceLeft(base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)
+  override def eval(input: Row): Any = {
+    if (seen.size == 0) {
+      null
+    } else {
+      Cast(Literal(
+        seen.reduceLeft(
+          
dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)),
+        dataType).eval(null)
+    }
+  }
 }
 
 case class CountDistinctFunction(

http://git-wip-us.apache.org/repos/asf/spark/blob/1d7ee2b7/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index 087b0ec..18afc5d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -83,29 +83,45 @@ case class GeneratedAggregate(
 
         AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, 
updateFunction :: Nil, result)
 
-      case Sum(expr) =>
-        val resultType = expr.dataType match {
-          case DecimalType.Fixed(precision, scale) =>
-            DecimalType(precision + 10, scale)
-          case _ =>
-            expr.dataType
-        }
+      case s @ Sum(expr) =>
+        val calcType =
+          expr.dataType match {
+            case DecimalType.Fixed(_, _) =>
+              DecimalType.Unlimited
+            case _ =>
+              expr.dataType
+          }
 
-        val currentSum = AttributeReference("currentSum", resultType, nullable 
= false)()
-        val initialValue = Cast(Literal(0L), resultType)
+        val currentSum = AttributeReference("currentSum", calcType, nullable = 
true)()
+        val initialValue = Literal(null, calcType)
 
         // Coalasce avoids double calculation...
         // but really, common sub expression elimination would be better....
-        val updateFunction = Coalesce(Add(expr, currentSum) :: currentSum :: 
Nil)
-        val result = currentSum
+        val zero = Cast(Literal(0), calcType)
+        val updateFunction = Coalesce(
+          Add(Coalesce(currentSum :: zero :: Nil), Cast(expr, calcType)) :: 
currentSum :: Nil)
+        val result =
+          expr.dataType match {
+            case DecimalType.Fixed(_, _) =>
+              Cast(currentSum, s.dataType)
+            case _ => currentSum
+          }
 
         AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, 
updateFunction :: Nil, result)
 
       case a @ Average(expr) =>
+        val calcType =
+          expr.dataType match {
+            case DecimalType.Fixed(_, _) =>
+              DecimalType.Unlimited
+            case _ =>
+              expr.dataType
+          }
+
         val currentCount = AttributeReference("currentCount", LongType, 
nullable = false)()
-        val currentSum = AttributeReference("currentSum", expr.dataType, 
nullable = false)()
+        val currentSum = AttributeReference("currentSum", calcType, nullable = 
false)()
         val initialCount = Literal(0L)
-        val initialSum = Cast(Literal(0L), expr.dataType)
+        val initialSum = Cast(Literal(0L), calcType)
 
         // If we're evaluating UnscaledValue(x), we can do Count on x 
directly, since its
         // UnscaledValue will be null if and only if x is null; helps with 
Average on decimals
@@ -115,17 +131,21 @@ case class GeneratedAggregate(
         }
 
         val updateCount = If(IsNotNull(toCount), Add(currentCount, 
Literal(1L)), currentCount)
-        val updateSum = Coalesce(Add(expr, currentSum) :: currentSum :: Nil)
-
-        val resultType = expr.dataType match {
-          case DecimalType.Fixed(precision, scale) =>
-            DecimalType(precision + 4, scale + 4)
-          case DecimalType.Unlimited =>
-            DecimalType.Unlimited
-          case _ =>
-            DoubleType
-        }
-        val result = Divide(Cast(currentSum, resultType), Cast(currentCount, 
resultType))
+        val updateSum = Coalesce(Add(Cast(expr, calcType), currentSum) :: 
currentSum :: Nil)
+
+        val result =
+          expr.dataType match {
+            case DecimalType.Fixed(_, _) =>
+              If(EqualTo(currentCount, Literal(0L)),
+                Literal(null, a.dataType),
+                Cast(Divide(
+                  Cast(currentSum, DecimalType.Unlimited),
+                  Cast(currentCount, DecimalType.Unlimited)), a.dataType))
+            case _ =>
+              If(EqualTo(currentCount, Literal(0L)),
+                Literal(null, a.dataType),
+                Divide(Cast(currentSum, a.dataType), Cast(currentCount, 
a.dataType)))
+          }
 
         AggregateEvaluation(
           currentCount :: currentSum :: Nil,

http://git-wip-us.apache.org/repos/asf/spark/blob/1d7ee2b7/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index e70ad89..94bd977 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -156,22 +156,58 @@ class DslQuerySuite extends QueryTest {
 
   test("average") {
     checkAnswer(
-      testData2.groupBy()(avg('a)),
+      testData2.aggregate(avg('a)),
       2.0)
+
+    checkAnswer(
+      testData2.aggregate(avg('a), sumDistinct('a)), // non-partial
+      (2.0, 6.0) :: Nil)
+
+    checkAnswer(
+      decimalData.aggregate(avg('a)),
+      BigDecimal(2.0))
+    checkAnswer(
+      decimalData.aggregate(avg('a), sumDistinct('a)), // non-partial
+      (BigDecimal(2.0), BigDecimal(6)) :: Nil)
+
+    checkAnswer(
+      decimalData.aggregate(avg('a cast DecimalType(10, 2))),
+      BigDecimal(2.0))
+    checkAnswer(
+      decimalData.aggregate(avg('a cast DecimalType(10, 2)), sumDistinct('a 
cast DecimalType(10, 2))), // non-partial
+      (BigDecimal(2.0), BigDecimal(6)) :: Nil)
   }
 
   test("null average") {
     checkAnswer(
-      testData3.groupBy()(avg('b)),
+      testData3.aggregate(avg('b)),
       2.0)
 
     checkAnswer(
-      testData3.groupBy()(avg('b), countDistinct('b)),
+      testData3.aggregate(avg('b), countDistinct('b)),
       (2.0, 1) :: Nil)
+
+    checkAnswer(
+      testData3.aggregate(avg('b), sumDistinct('b)), // non-partial
+      (2.0, 2.0) :: Nil)
+  }
+
+  test("zero average") {
+    checkAnswer(
+      emptyTableData.aggregate(avg('a)),
+      null)
+
+    checkAnswer(
+      emptyTableData.aggregate(avg('a), sumDistinct('b)), // non-partial
+      (null, null) :: Nil)
   }
 
   test("count") {
     assert(testData2.count() === testData2.map(_ => 1).count())
+
+    checkAnswer(
+      testData2.aggregate(count('a), sumDistinct('a)), // non-partial
+      (6, 6.0) :: Nil)
   }
 
   test("null count") {
@@ -186,13 +222,34 @@ class DslQuerySuite extends QueryTest {
     )
 
     checkAnswer(
-      testData3.groupBy()(count('a), count('b), count(1), countDistinct('a), 
countDistinct('b)),
+      testData3.aggregate(count('a), count('b), count(1), countDistinct('a), 
countDistinct('b)),
       (2, 1, 2, 2, 1) :: Nil
     )
+
+    checkAnswer(
+      testData3.aggregate(count('b), countDistinct('b), sumDistinct('b)), // 
non-partial
+      (1, 1, 2) :: Nil
+    )
   }
 
   test("zero count") {
     assert(emptyTableData.count() === 0)
+
+    checkAnswer(
+      emptyTableData.aggregate(count('a), sumDistinct('a)), // non-partial
+      (0, null) :: Nil)
+  }
+
+  test("zero sum") {
+    checkAnswer(
+      emptyTableData.aggregate(sum('a)),
+      null)
+  }
+
+  test("zero sum distinct") {
+    checkAnswer(
+      emptyTableData.aggregate(sumDistinct('a)),
+      null)
   }
 
   test("except") {

http://git-wip-us.apache.org/repos/asf/spark/blob/1d7ee2b7/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 92b49e8..933e027 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -54,6 +54,17 @@ object TestData {
       TestData2(3, 2) :: Nil).toSchemaRDD
   testData2.registerTempTable("testData2")
 
+  case class DecimalData(a: BigDecimal, b: BigDecimal)
+  val decimalData =
+    TestSQLContext.sparkContext.parallelize(
+      DecimalData(1, 1) ::
+      DecimalData(1, 2) ::
+      DecimalData(2, 1) ::
+      DecimalData(2, 2) ::
+      DecimalData(3, 1) ::
+      DecimalData(3, 2) :: Nil).toSchemaRDD
+  decimalData.registerTempTable("decimalData")
+
   case class BinaryData(a: Array[Byte], b: Int)
   val binaryData =
     TestSQLContext.sparkContext.parallelize(


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to