This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new bc4a676 [SPARK-28201][SQL] Revisit MakeDecimal behavior on overflow
bc4a676 is described below
commit bc4a676b2752c691f7c1d824a58387dbfac6d695
Author: Marco Gaido <[email protected]>
AuthorDate: Mon Jul 1 11:54:58 2019 +0800
[SPARK-28201][SQL] Revisit MakeDecimal behavior on overflow
## What changes were proposed in this pull request?
SPARK-23179 introduced a flag to control the behavior on decimal overflow:
when `spark.sql.decimalOperations.nullOnOverflow` is true (the default and
traditional Spark behavior), `null` is returned; when that conf is false, an
`ArithmeticException` is thrown (as required by the SQL standard and as other
DBs behave).
So far, `MakeDecimal` has behaved inconsistently: in codegen mode it returned
`null`, like the other operators, but in interpreted mode it threw an
`IllegalArgumentException`.
This PR aligns `MakeDecimal`'s behavior with that of the other operators, as
defined in SPARK-23179. Both modes now either return `null` or throw an
`ArithmeticException`, depending on the value of
`spark.sql.decimalOperations.nullOnOverflow`.
Credit for this PR goes to mickjermsurawong-stripe, who pointed out the wrong
behavior in #20350.
## How was this patch tested?
Improved unit tests.
Closes #25010 from mgaido91/SPARK-28201.
Authored-by: Marco Gaido <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../catalyst/expressions/decimalExpressions.scala | 32 ++++++++++++++++++----
.../scala/org/apache/spark/sql/types/Decimal.scala | 9 +++---
.../expressions/DecimalExpressionSuite.scala | 20 ++++++++++++--
.../org/apache/spark/sql/types/DecimalSuite.scala | 10 +++----
4 files changed, 54 insertions(+), 17 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
index ad7f7dd..b5b712c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext,
EmptyBlock, ExprCode}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
/**
@@ -46,19 +47,38 @@ case class UnscaledValue(child: Expression) extends UnaryExpression {
*/
case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends
UnaryExpression {
+ private val nullOnOverflow = SQLConf.get.decimalOperationsNullOnOverflow
+
override def dataType: DataType = DecimalType(precision, scale)
- override def nullable: Boolean = true
+ override def nullable: Boolean = child.nullable || nullOnOverflow
override def toString: String = s"MakeDecimal($child,$precision,$scale)"
- protected override def nullSafeEval(input: Any): Any =
- Decimal(input.asInstanceOf[Long], precision, scale)
+ protected override def nullSafeEval(input: Any): Any = {
+ val longInput = input.asInstanceOf[Long]
+ val result = new Decimal()
+ if (nullOnOverflow) {
+ result.setOrNull(longInput, precision, scale)
+ } else {
+ result.set(longInput, precision, scale)
+ }
+ }
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, eval => {
+ val setMethod = if (nullOnOverflow) {
+ "setOrNull"
+ } else {
+ "set"
+ }
+ val setNull = if (nullable) {
+ s"${ev.isNull} = ${ev.value} == null;"
+ } else {
+ ""
+ }
s"""
- ${ev.value} = (new Decimal()).setOrNull($eval, $precision, $scale);
- ${ev.isNull} = ${ev.value} == null;
- """
+ |${ev.value} = (new Decimal()).$setMethod($eval, $precision, $scale);
+ |$setNull
+ |""".stripMargin
})
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index b7b7097..1bf322a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -76,7 +76,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
*/
def set(unscaled: Long, precision: Int, scale: Int): Decimal = {
if (setOrNull(unscaled, precision, scale) == null) {
- throw new IllegalArgumentException("Unscaled value too large for
precision")
+ throw new ArithmeticException("Unscaled value too large for precision")
}
this
}
@@ -111,9 +111,10 @@ final class Decimal extends Ordered[Decimal] with Serializable {
*/
def set(decimal: BigDecimal, precision: Int, scale: Int): Decimal = {
this.decimalVal = decimal.setScale(scale, ROUND_HALF_UP)
- require(
- decimalVal.precision <= precision,
- s"Decimal precision ${decimalVal.precision} exceeds max precision
$precision")
+ if (decimalVal.precision > precision) {
+ throw new ArithmeticException(
+ s"Decimal precision ${decimalVal.precision} exceeds max precision
$precision")
+ }
this.longVal = 0L
this._precision = precision
this._scale = scale
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala
index d14eceb..fc5e8dc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{Decimal, DecimalType, LongType}
class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -31,8 +32,23 @@ class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
}
test("MakeDecimal") {
- checkEvaluation(MakeDecimal(Literal(101L), 3, 1), Decimal("10.1"))
- checkEvaluation(MakeDecimal(Literal.create(null, LongType), 3, 1), null)
+ withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "true") {
+ checkEvaluation(MakeDecimal(Literal(101L), 3, 1), Decimal("10.1"))
+ checkEvaluation(MakeDecimal(Literal.create(null, LongType), 3, 1), null)
+ val overflowExpr = MakeDecimal(Literal.create(1000L, LongType), 3, 1)
+ checkEvaluation(overflowExpr, null)
+ checkEvaluationWithMutableProjection(overflowExpr, null)
+ evaluateWithoutCodegen(overflowExpr, null)
+ checkEvaluationWithUnsafeProjection(overflowExpr, null)
+ }
+ withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "false") {
+ checkEvaluation(MakeDecimal(Literal(101L), 3, 1), Decimal("10.1"))
+ checkEvaluation(MakeDecimal(Literal.create(null, LongType), 3, 1), null)
+ val overflowExpr = MakeDecimal(Literal.create(1000L, LongType), 3, 1)
+
intercept[ArithmeticException](checkEvaluationWithMutableProjection(overflowExpr,
null))
+ intercept[ArithmeticException](evaluateWithoutCodegen(overflowExpr,
null))
+
intercept[ArithmeticException](checkEvaluationWithUnsafeProjection(overflowExpr,
null))
+ }
}
test("PromotePrecision") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
index 8abd762..d69bb2f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
@@ -56,11 +56,11 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
checkDecimal(Decimal(1000000000000000000L, 20, 2), "10000000000000000.00",
20, 2)
checkDecimal(Decimal(Long.MaxValue), Long.MaxValue.toString, 20, 0)
checkDecimal(Decimal(Long.MinValue), Long.MinValue.toString, 20, 0)
- intercept[IllegalArgumentException](Decimal(170L, 2, 1))
- intercept[IllegalArgumentException](Decimal(170L, 2, 0))
- intercept[IllegalArgumentException](Decimal(BigDecimal("10.030"), 2, 1))
- intercept[IllegalArgumentException](Decimal(BigDecimal("-9.95"), 2, 1))
- intercept[IllegalArgumentException](Decimal(1e17.toLong, 17, 0))
+ intercept[ArithmeticException](Decimal(170L, 2, 1))
+ intercept[ArithmeticException](Decimal(170L, 2, 0))
+ intercept[ArithmeticException](Decimal(BigDecimal("10.030"), 2, 1))
+ intercept[ArithmeticException](Decimal(BigDecimal("-9.95"), 2, 1))
+ intercept[ArithmeticException](Decimal(1e17.toLong, 17, 0))
}
test("creating decimals with negative scale") {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]