Repository: spark Updated Branches: refs/heads/branch-1.0 f5399631c -> 50e234ba5
[SPARK-1915] [SQL] AverageFunction should not count if the evaluated value is null. Average values differ depending on whether the calculation is done partially or not, because `AverageFunction` (in the non-partial calculation) counts rows even if the evaluated value is null. Author: Takuya UESHIN <[email protected]> Closes #862 from ueshin/issues/SPARK-1915 and squashes the following commits: b1ff3c0 [Takuya UESHIN] Modify AverageFunction not to count if the evaluated value is null. (cherry picked from commit 3b0babad1f0856ee16f9d58e1ead30779a4a6310) Signed-off-by: Reynold Xin <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/50e234ba Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/50e234ba Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/50e234ba Branch: refs/heads/branch-1.0 Commit: 50e234ba510acac0f75c080b1b1ea681a3a28449 Parents: f539963 Author: Takuya UESHIN <[email protected]> Authored: Tue May 27 14:55:23 2014 -0700 Committer: Reynold Xin <[email protected]> Committed: Tue May 27 14:55:34 2014 -0700 ---------------------------------------------------------------------- .../spark/sql/catalyst/expressions/aggregates.scala | 9 ++++++--- .../test/scala/org/apache/spark/sql/DslQuerySuite.scala | 10 ++++++++++ 2 files changed, 16 insertions(+), 3 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/50e234ba/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index b49a461..c902433 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -281,14 +281,17 @@ case class AverageFunction(expr: Expression, base: AggregateExpression) private val sum = MutableLiteral(zero.eval(EmptyRow)) private val sumAsDouble = Cast(sum, DoubleType) - private val addFunction = Add(sum, Coalesce(Seq(expr, zero))) + private def addFunction(value: Any) = Add(sum, Literal(value)) override def eval(input: Row): Any = sumAsDouble.eval(EmptyRow).asInstanceOf[Double] / count.toDouble override def update(input: Row): Unit = { - count += 1 - sum.update(addFunction, input) + val evaluatedExpr = expr.eval(input) + if (evaluatedExpr != null) { + count += 1 + sum.update(addFunction(evaluatedExpr), input) + } } } http://git-wip-us.apache.org/repos/asf/spark/blob/50e234ba/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index 8197e8a..fb599e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -115,6 +115,16 @@ class DslQuerySuite extends QueryTest { 2.0) } + test("null average") { + checkAnswer( + testData3.groupBy()(Average('b)), + 2.0) + + checkAnswer( + testData3.groupBy()(Average('b), CountDistinct('b :: Nil)), + (2.0, 1) :: Nil) + } + test("count") { assert(testData2.count() === testData2.map(_ => 1).count()) }
