This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new b84dc909a885 [SPARK-50769][SQL] Fix ClassCastException in
HistogramNumeric
b84dc909a885 is described below
commit b84dc909a8856388faddc154c6a1d3aba271474e
Author: Linhong Liu <[email protected]>
AuthorDate: Thu Jan 9 12:08:16 2025 +0800
[SPARK-50769][SQL] Fix ClassCastException in HistogramNumeric
### What changes were proposed in this pull request?
`HistogramNumeric` accepts `NumericType` input, but it doesn't properly handle
`DecimalType` during execution. As a result, a `ClassCastException` is thrown
when trying to convert a `Decimal` value to a `Double`.
### Why are the changes needed?
bug fix
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Unit test: added a `histogram_numeric` query over a `DECIMAL(4, 2)` column to `group-by.sql`.
```
build/sbt "sql/testOnly *SQLQueryTestSuite -- -z group-by.sql"
```
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #49418 from linhongliu-db/SPARK-50769.
Authored-by: Linhong Liu <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../sql/catalyst/expressions/aggregate/HistogramNumeric.scala | 10 +++++++++-
.../test/resources/sql-tests/analyzer-results/group-by.sql.out | 9 +++++++++
sql/core/src/test/resources/sql-tests/inputs/group-by.sql | 2 ++
sql/core/src/test/resources/sql-tests/results/group-by.sql.out | 9 +++++++++
4 files changed, 29 insertions(+), 1 deletion(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumeric.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumeric.scala
index eda2c742ab4b..142f4a4eae4c 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumeric.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumeric.scala
@@ -126,7 +126,10 @@ case class HistogramNumeric(
// Ignore empty rows, for example: histogram_numeric(null)
if (value != null) {
// Convert the value to a double value
- val doubleValue = value.asInstanceOf[Number].doubleValue
+ val doubleValue = value match {
+ case d: Decimal => d.toDouble
+ case o => o.asInstanceOf[Number].doubleValue()
+ }
buffer.add(doubleValue)
}
buffer
@@ -162,6 +165,11 @@ case class HistogramNumeric(
case ShortType => coord.x.toShort
case _: DayTimeIntervalType | LongType | TimestampType |
TimestampNTZType =>
coord.x.toLong
+ case d: DecimalType =>
+ val bigDecimal = BigDecimal
+ .decimal(coord.x, new java.math.MathContext(d.precision))
+ .setScale(d.scale, BigDecimal.RoundingMode.HALF_UP)
+ Decimal(bigDecimal)
case _ => coord.x
}
array(index) = InternalRow.apply(result, coord.y)
diff --git
a/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out
b/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out
index 34ff2a2186f0..304b382c7bbe 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out
@@ -1055,6 +1055,15 @@ Aggregate [histogram_numeric(col#xL, 3, 0, 0) AS
histogram_numeric(col, 3)#x]
+- LocalRelation [col#xL]
+-- !query
+SELECT histogram_numeric(col, 3) FROM VALUES
+ (CAST(1 AS DECIMAL(4, 2))), (CAST(2 AS DECIMAL(4, 2))), (CAST(3 AS
DECIMAL(4, 2))) AS tab(col)
+-- !query analysis
+Aggregate [histogram_numeric(col#x, 3, 0, 0) AS histogram_numeric(col, 3)#x]
++- SubqueryAlias tab
+ +- LocalRelation [col#x]
+
+
-- !query
SELECT histogram_numeric(col, 3) FROM VALUES (TIMESTAMP '2017-03-01 00:00:00'),
(TIMESTAMP '2017-04-01 00:00:00'), (TIMESTAMP '2017-05-01 00:00:00') AS
tab(col)
diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
index 6dd0adbc8722..0cc1f62b0583 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
@@ -221,6 +221,8 @@ SELECT histogram_numeric(col, 3) FROM VALUES
(CAST(1 AS SMALLINT)), (CAST(2 AS SMALLINT)), (CAST(3 AS SMALLINT)) AS
tab(col);
SELECT histogram_numeric(col, 3) FROM VALUES
(CAST(1 AS BIGINT)), (CAST(2 AS BIGINT)), (CAST(3 AS BIGINT)) AS tab(col);
+SELECT histogram_numeric(col, 3) FROM VALUES
+ (CAST(1 AS DECIMAL(4, 2))), (CAST(2 AS DECIMAL(4, 2))), (CAST(3 AS
DECIMAL(4, 2))) AS tab(col);
SELECT histogram_numeric(col, 3) FROM VALUES (TIMESTAMP '2017-03-01 00:00:00'),
(TIMESTAMP '2017-04-01 00:00:00'), (TIMESTAMP '2017-05-01 00:00:00') AS
tab(col);
SELECT histogram_numeric(col, 3) FROM VALUES (INTERVAL '100-00' YEAR TO MONTH),
diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
index 633133ad7e4d..98ad1a0a5bba 100644
--- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
@@ -987,6 +987,15 @@ struct<histogram_numeric(col,
3):array<struct<x:bigint,y:double>>>
[{"x":1,"y":1.0},{"x":2,"y":1.0},{"x":3,"y":1.0}]
+-- !query
+SELECT histogram_numeric(col, 3) FROM VALUES
+ (CAST(1 AS DECIMAL(4, 2))), (CAST(2 AS DECIMAL(4, 2))), (CAST(3 AS
DECIMAL(4, 2))) AS tab(col)
+-- !query schema
+struct<histogram_numeric(col, 3):array<struct<x:decimal(4,2),y:double>>>
+-- !query output
+[{"x":1.00,"y":1.0},{"x":2.00,"y":1.0},{"x":3.00,"y":1.0}]
+
+
-- !query
SELECT histogram_numeric(col, 3) FROM VALUES (TIMESTAMP '2017-03-01 00:00:00'),
(TIMESTAMP '2017-04-01 00:00:00'), (TIMESTAMP '2017-05-01 00:00:00') AS
tab(col)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]