Repository: spark
Updated Branches:
  refs/heads/branch-1.0 f5399631c -> 50e234ba5


[SPARK-1915] [SQL] AverageFunction should not count if the evaluated value is 
null.

Average values differ depending on whether the calculation is done partially 
or not.
This is because `AverageFunction` (used in the non-partial calculation) 
increments the count even if the evaluated value is null.

Author: Takuya UESHIN <[email protected]>

Closes #862 from ueshin/issues/SPARK-1915 and squashes the following commits:

b1ff3c0 [Takuya UESHIN] Modify AverageFunction not to count if the evaluated 
value is null.

(cherry picked from commit 3b0babad1f0856ee16f9d58e1ead30779a4a6310)
Signed-off-by: Reynold Xin <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/50e234ba
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/50e234ba
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/50e234ba

Branch: refs/heads/branch-1.0
Commit: 50e234ba510acac0f75c080b1b1ea681a3a28449
Parents: f539963
Author: Takuya UESHIN <[email protected]>
Authored: Tue May 27 14:55:23 2014 -0700
Committer: Reynold Xin <[email protected]>
Committed: Tue May 27 14:55:34 2014 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/expressions/aggregates.scala       |  9 ++++++---
 .../test/scala/org/apache/spark/sql/DslQuerySuite.scala   | 10 ++++++++++
 2 files changed, 16 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/50e234ba/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index b49a461..c902433 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -281,14 +281,17 @@ case class AverageFunction(expr: Expression, base: 
AggregateExpression)
   private val sum = MutableLiteral(zero.eval(EmptyRow))
   private val sumAsDouble = Cast(sum, DoubleType)
 
-  private val addFunction = Add(sum, Coalesce(Seq(expr, zero)))
+  private def addFunction(value: Any) = Add(sum, Literal(value))
 
   override def eval(input: Row): Any =
     sumAsDouble.eval(EmptyRow).asInstanceOf[Double] / count.toDouble
 
   override def update(input: Row): Unit = {
-    count += 1
-    sum.update(addFunction, input)
+    val evaluatedExpr = expr.eval(input)
+    if (evaluatedExpr != null) {
+      count += 1
+      sum.update(addFunction(evaluatedExpr), input)
+    }
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/50e234ba/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index 8197e8a..fb599e1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -115,6 +115,16 @@ class DslQuerySuite extends QueryTest {
       2.0)
   }
 
+  test("null average") {
+    checkAnswer(
+      testData3.groupBy()(Average('b)),
+      2.0)
+
+    checkAnswer(
+      testData3.groupBy()(Average('b), CountDistinct('b :: Nil)),
+      (2.0, 1) :: Nil)
+  }
+
   test("count") {
     assert(testData2.count() === testData2.map(_ => 1).count())
   }

Reply via email to