Repository: spark
Updated Branches:
  refs/heads/branch-2.1 43084b3cc -> 66a7ca28a
[SPARK-19691][SQL][BRANCH-2.1] Fix ClassCastException when calculating percentile of decimal column

## What changes were proposed in this pull request?

This is a backport of the following two commits:
https://github.com/apache/spark/commit/93aa4271596a30752dc5234d869c3ae2f6e8e723

This PR fixes the `ClassCastException` shown below:

```
scala> spark.range(10).selectExpr("cast (id as decimal) as x").selectExpr("percentile(x, 0.5)").collect()
java.lang.ClassCastException: org.apache.spark.sql.types.Decimal cannot be cast to java.lang.Number
  at org.apache.spark.sql.catalyst.expressions.aggregate.Percentile.update(Percentile.scala:141)
  at org.apache.spark.sql.catalyst.expressions.aggregate.Percentile.update(Percentile.scala:58)
  at org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate.update(interfaces.scala:514)
  at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$1$$anonfun$applyOrElse$1.apply(AggregationIterator.scala:171)
  at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$1$$anonfun$applyOrElse$1.apply(AggregationIterator.scala:171)
  at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$generateProcessRow$1.apply(AggregationIterator.scala:187)
  at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$generateProcessRow$1.apply(AggregationIterator.scala:181)
  at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.processInputs(ObjectAggregationIterator.scala:151)
  at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.<init>(ObjectAggregationIterator.scala:78)
  at org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec$$anonfun$doExecute$1$$anonfun$2.apply(ObjectHashAggregateExec.scala:109)
  at ...
```

The fix converts Catalyst values (i.e., `Decimal`) into Scala ones by using `CatalystTypeConverters`.

## How was this patch tested?

Added a test in `DataFrameSuite`.

Author: Takeshi Yamamuro <[email protected]>

Closes #17046 from maropu/SPARK-19691-BACKPORT2.1.
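The heart of the change is visible in the diff below: `Percentile` now buffers keys as `AnyRef` and converts them to `Double` only when computing the result. A minimal standalone sketch of that conversion follows; the helper mirrors the `toDoubleValue` added by the patch, while the wrapper object and sample values are illustrative only:

```scala
import org.apache.spark.sql.types.Decimal

// Illustrative sketch only; mirrors the toDoubleValue helper added by this patch.
object PercentileFixSketch {
  // Catalyst evaluates a DecimalType column to a Decimal, which is not a
  // java.lang.Number, so the old asInstanceOf[Number] cast threw at runtime.
  // Handling both representations explicitly avoids the ClassCastException.
  def toDoubleValue(d: Any): Double = d match {
    case d: Decimal => d.toDouble    // Catalyst's internal decimal type
    case n: Number  => n.doubleValue // boxed Int, Long, Double, ...
  }

  def main(args: Array[String]): Unit = {
    println(toDoubleValue(Decimal(4.5)))                  // 4.5
    println(toDoubleValue(java.lang.Integer.valueOf(7)))  // 7.0
  }
}
```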
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/66a7ca28
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/66a7ca28
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/66a7ca28

Branch: refs/heads/branch-2.1
Commit: 66a7ca28a9de92e67ce24896a851a0c96c92aec6
Parents: 43084b3
Author: Takeshi Yamamuro <[email protected]>
Authored: Fri Feb 24 10:54:00 2017 +0100
Committer: Herman van Hovell <[email protected]>
Committed: Fri Feb 24 10:54:00 2017 +0100

----------------------------------------------------------------------
 .../expressions/aggregate/Percentile.scala      | 42 +++++++++++---------
 .../expressions/aggregate/PercentileSuite.scala |  6 +--
 .../org/apache/spark/sql/DataFrameSuite.scala   |  5 +++
 3 files changed, 31 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/66a7ca28/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
index 356e088..8dd4f2c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
@@ -57,7 +57,7 @@ case class Percentile(
     child: Expression,
     percentageExpression: Expression,
     mutableAggBufferOffset: Int = 0,
-    inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[OpenHashMap[Number, Long]] {
+    inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[OpenHashMap[AnyRef, Long]] {
 
   def this(child: Expression, percentageExpression: Expression) = {
     this(child, percentageExpression, 0, 0)
@@ -123,13 +123,18 @@ case class Percentile(
     }
   }
 
-  override def createAggregationBuffer(): OpenHashMap[Number, Long] = {
+  private def toDoubleValue(d: Any): Double = d match {
+    case d: Decimal => d.toDouble
+    case n: Number => n.doubleValue
+  }
+
+  override def createAggregationBuffer(): OpenHashMap[AnyRef, Long] = {
     // Initialize new counts map instance here.
-    new OpenHashMap[Number, Long]()
+    new OpenHashMap[AnyRef, Long]()
   }
 
-  override def update(buffer: OpenHashMap[Number, Long], input: InternalRow): Unit = {
-    val key = child.eval(input).asInstanceOf[Number]
+  override def update(buffer: OpenHashMap[AnyRef, Long], input: InternalRow): Unit = {
+    val key = child.eval(input).asInstanceOf[AnyRef]
 
     // Null values are ignored in counts map.
     if (key != null) {
@@ -137,30 +142,30 @@ case class Percentile(
     }
   }
 
-  override def merge(buffer: OpenHashMap[Number, Long], other: OpenHashMap[Number, Long]): Unit = {
+  override def merge(buffer: OpenHashMap[AnyRef, Long], other: OpenHashMap[AnyRef, Long]): Unit = {
     other.foreach { case (key, count) =>
       buffer.changeValue(key, count, _ + count)
     }
   }
 
-  override def eval(buffer: OpenHashMap[Number, Long]): Any = {
+  override def eval(buffer: OpenHashMap[AnyRef, Long]): Any = {
     generateOutput(getPercentiles(buffer))
   }
 
-  private def getPercentiles(buffer: OpenHashMap[Number, Long]): Seq[Double] = {
+  private def getPercentiles(buffer: OpenHashMap[AnyRef, Long]): Seq[Double] = {
     if (buffer.isEmpty) {
       return Seq.empty
     }
 
     val sortedCounts = buffer.toSeq.sortBy(_._1)(
-      child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[Number]])
+      child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[AnyRef]])
     val accumlatedCounts = sortedCounts.scanLeft(sortedCounts.head._1, 0L) {
       case ((key1, count1), (key2, count2)) => (key2, count1 + count2)
     }.tail
     val maxPosition = accumlatedCounts.last._2 - 1
 
     percentages.map { percentile =>
-      getPercentile(accumlatedCounts, maxPosition * percentile).doubleValue()
+      getPercentile(accumlatedCounts, maxPosition * percentile)
     }
   }
 
@@ -180,7 +185,7 @@ case class Percentile(
    * This function has been based upon similar function from HIVE
    * `org.apache.hadoop.hive.ql.udf.UDAFPercentile.getPercentile()`.
    */
-  private def getPercentile(aggreCounts: Seq[(Number, Long)], position: Double): Number = {
+  private def getPercentile(aggreCounts: Seq[(AnyRef, Long)], position: Double): Double = {
     // We may need to do linear interpolation to get the exact percentile
     val lower = position.floor.toLong
     val higher = position.ceil.toLong
@@ -193,18 +198,17 @@ case class Percentile(
     val lowerKey = aggreCounts(lowerIndex)._1
     if (higher == lower) {
       // no interpolation needed because position does not have a fraction
-      return lowerKey
+      return toDoubleValue(lowerKey)
     }
 
     val higherKey = aggreCounts(higherIndex)._1
     if (higherKey == lowerKey) {
       // no interpolation needed because lower position and higher position has the same key
-      return lowerKey
+      return toDoubleValue(lowerKey)
     }
 
     // Linear interpolation to get the exact percentile
-    return (higher - position) * lowerKey.doubleValue() +
-      (position - lower) * higherKey.doubleValue()
+    (higher - position) * toDoubleValue(lowerKey) + (position - lower) * toDoubleValue(higherKey)
   }
 
   /**
@@ -218,7 +222,7 @@ case class Percentile(
     }
   }
 
-  override def serialize(obj: OpenHashMap[Number, Long]): Array[Byte] = {
+  override def serialize(obj: OpenHashMap[AnyRef, Long]): Array[Byte] = {
     val buffer = new Array[Byte](4 << 10)  // 4K
     val bos = new ByteArrayOutputStream()
     val out = new DataOutputStream(bos)
@@ -241,11 +245,11 @@ case class Percentile(
     }
   }
 
-  override def deserialize(bytes: Array[Byte]): OpenHashMap[Number, Long] = {
+  override def deserialize(bytes: Array[Byte]): OpenHashMap[AnyRef, Long] = {
     val bis = new ByteArrayInputStream(bytes)
     val ins = new DataInputStream(bis)
     try {
-      val counts = new OpenHashMap[Number, Long]
+      val counts = new OpenHashMap[AnyRef, Long]
      // Read unsafeRow size and content in bytes.
      var sizeOfNextRow = ins.readInt()
      while (sizeOfNextRow >= 0) {
@@ -254,7 +258,7 @@ case class Percentile(
         val row = new UnsafeRow(2)
         row.pointTo(bs, sizeOfNextRow)
         // Insert the pairs into counts map.
-        val key = row.get(0, child.dataType).asInstanceOf[Number]
+        val key = row.get(0, child.dataType)
         val count = row.get(1, LongType).asInstanceOf[Long]
         counts.update(key, count)
         sizeOfNextRow = ins.readInt()


http://git-wip-us.apache.org/repos/asf/spark/blob/66a7ca28/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
index f060ecc..d7c2527 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
@@ -38,12 +38,12 @@ class PercentileSuite extends SparkFunSuite {
     val agg = new Percentile(BoundReference(0, IntegerType, true), Literal(0.5))
 
     // Check empty serialize and deserialize
-    val buffer = new OpenHashMap[Number, Long]()
+    val buffer = new OpenHashMap[AnyRef, Long]()
     assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer))
 
     // Check non-empty buffer serializa and deserialize.
     data.foreach { key =>
-      buffer.changeValue(key, 1L, _ + 1L)
+      buffer.changeValue(new Integer(key), 1L, _ + 1L)
     }
     assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer))
   }
@@ -233,7 +233,7 @@ class PercentileSuite extends SparkFunSuite {
   }
 
   private def compareEquals(
-      left: OpenHashMap[Number, Long], right: OpenHashMap[Number, Long]): Boolean = {
+      left: OpenHashMap[AnyRef, Long], right: OpenHashMap[AnyRef, Long]): Boolean = {
     left.size == right.size && left.forall { case (key, count) =>
       right.apply(key) == count
     }


http://git-wip-us.apache.org/repos/asf/spark/blob/66a7ca28/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 312cd17..22dfc46 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1734,4 +1734,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val df = spark.createDataFrame(spark.sparkContext.makeRDD(rows), schema)
     assert(df.filter($"array1" === $"array2").count() == 1)
   }
+
+  test("SPARK-19691 Calculating percentile of decimal column fails with ClassCastException") {
+    val df = spark.range(1).selectExpr("CAST(id as DECIMAL) as x").selectExpr("percentile(x, 0.5)")
+    checkAnswer(df, Row(BigDecimal(0.0)) :: Nil)
+  }
 }
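With this patch applied, the query from the bug report should run to completion. A quick spark-shell check (hypothetical verification session, not part of the commit; assumes a branch-2.1 build containing this fix, and notes that the median of 0..9 interpolates to 4.5):

```scala
spark.range(10)
  .selectExpr("CAST(id AS DECIMAL) AS x")
  .selectExpr("percentile(x, 0.5)")
  .collect()
// expected: Array([4.5]) instead of a ClassCastException
```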
