This is an automated email from the ASF dual-hosted git repository.

dbtsai pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-2.4 by this push:
     new 921c22b  [SPARK-26706][SQL] Fix `illegalNumericPrecedence` for ByteType
921c22b is described below

commit 921c22b1fffc4844aa05c201ba15986be34a3782
Author: Anton Okolnychyi <aokolnyc...@apple.com>
AuthorDate: Thu Jan 24 00:12:26 2019 +0000

    [SPARK-26706][SQL] Fix `illegalNumericPrecedence` for ByteType

    This PR contains a minor change in `Cast.mayTruncate` that fixes its logic for bytes.

    Right now, `mayTruncate(LongType, ByteType)` returns `false` while `mayTruncate(LongType, ShortType)`
    returns `true` (the argument order is `from`, `to`). Consequently, `spark.range(1, 3).as[Byte]` and
    `spark.range(1, 3).as[Short]` behave differently: the lossy up-cast to `Short` is rejected during
    analysis, while the equally lossy up-cast to `Byte` slips through. This bug can silently corrupt data.

    ```scala
    // executes silently even though each Long is truncated to a Byte
    spark.range(Long.MaxValue - 10, Long.MaxValue).as[Byte]
      .map(b => b - 1)
      .show()
    +-----+
    |value|
    +-----+
    |  -12|
    |  -11|
    |  -10|
    |   -9|
    |   -8|
    |   -7|
    |   -6|
    |   -5|
    |   -4|
    |   -3|
    +-----+

    // throws an AnalysisException: Cannot up cast `id` from bigint to smallint as it may truncate
    spark.range(Long.MaxValue - 10, Long.MaxValue).as[Short]
      .map(s => s - 1)
      .show()
    ```

    This PR comes with a set of unit tests.

    Closes #23632 from aokolnychyi/cast-fix.

    Authored-by: Anton Okolnychyi <aokolnyc...@apple.com>
    Signed-off-by: DB Tsai <d_t...@apple.com>
---
 .../spark/sql/catalyst/expressions/Cast.scala       |  2 +-
 .../spark/sql/catalyst/expressions/CastSuite.scala  | 36 ++++++++++++++++++++++
 .../scala/org/apache/spark/sql/DatasetSuite.scala   |  9 ++++++
 3 files changed, 46 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index ee463bf..ac02dac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -131,7 +131,7 @@ object Cast {
   private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = {
     val fromPrecedence = TypeCoercion.numericPrecedence.indexOf(from)
     val toPrecedence = TypeCoercion.numericPrecedence.indexOf(to)
-    toPrecedence > 0 && fromPrecedence > toPrecedence
+    toPrecedence >= 0 && fromPrecedence > toPrecedence
  }
 
   /**
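The one-character fix above works because `ByteType` is the first element of `TypeCoercion.numericPrecedence`, so `indexOf` returns 0 for it and the old `toPrecedence > 0` guard could never mark a cast to bytes as truncating. A minimal standalone sketch (plain Scala, with the precedence list modelled as strings rather than Spark's internal types; the object and method names are illustrative, not Spark API) shows the off-by-one:

```scala
object IllegalNumericPrecedenceSketch {
  // Mirrors the order of TypeCoercion.numericPrecedence: ByteType sits at index 0.
  val numericPrecedence: Seq[String] =
    Seq("ByteType", "ShortType", "IntegerType", "LongType", "FloatType", "DoubleType")

  // Old predicate: `toPrecedence > 0` silently exempts ByteType (index 0) as a target.
  def illegalNumericPrecedenceOld(from: String, to: String): Boolean = {
    val fromPrecedence = numericPrecedence.indexOf(from)
    val toPrecedence = numericPrecedence.indexOf(to)
    toPrecedence > 0 && fromPrecedence > toPrecedence
  }

  // Fixed predicate: `toPrecedence >= 0` treats ByteType like every other numeric type,
  // while still excluding types not in the list (indexOf returns -1 for those).
  def illegalNumericPrecedenceFixed(from: String, to: String): Boolean = {
    val fromPrecedence = numericPrecedence.indexOf(from)
    val toPrecedence = numericPrecedence.indexOf(to)
    toPrecedence >= 0 && fromPrecedence > toPrecedence
  }

  def main(args: Array[String]): Unit = {
    println(illegalNumericPrecedenceOld("LongType", "ByteType"))    // false: the bug
    println(illegalNumericPrecedenceFixed("LongType", "ByteType"))  // true: truncation flagged
    println(illegalNumericPrecedenceFixed("LongType", "ShortType")) // true: unchanged behaviour
  }
}
```

In other words, the `> 0` check conflated `indexOf`'s -1 "not found" result with ByteType's legitimate index 0; `>= 0` only excludes the former.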
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index d9f32c0..b1531ba 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -23,6 +23,7 @@ import java.util.{Calendar, Locale, TimeZone}
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCoercion.numericPrecedence
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
 import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
@@ -953,4 +954,39 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
     val ret6 = cast(Literal.create((1, Map(1 -> "a", 2 -> "b", 3 -> "c"))), StringType)
     checkEvaluation(ret6, "[1, [1 -> a, 2 -> b, 3 -> c]]")
   }
+
+  test("SPARK-26706: Fix Cast.mayTruncate for bytes") {
+    assert(!Cast.mayTruncate(ByteType, ByteType))
+    assert(!Cast.mayTruncate(DecimalType.ByteDecimal, ByteType))
+    assert(Cast.mayTruncate(ShortType, ByteType))
+    assert(Cast.mayTruncate(IntegerType, ByteType))
+    assert(Cast.mayTruncate(LongType, ByteType))
+    assert(Cast.mayTruncate(FloatType, ByteType))
+    assert(Cast.mayTruncate(DoubleType, ByteType))
+    assert(Cast.mayTruncate(DecimalType.IntDecimal, ByteType))
+  }
+
+  test("canSafeCast and mayTruncate must be consistent for numeric types") {
+    import DataTypeTestUtils._
+
+    def isCastSafe(from: NumericType, to: NumericType): Boolean = (from, to) match {
+      case (_, dt: DecimalType) => dt.isWiderThan(from)
+      case (dt: DecimalType, _) => dt.isTighterThan(to)
+      case _ => numericPrecedence.indexOf(from) <= numericPrecedence.indexOf(to)
+    }
+
+    numericTypes.foreach { from =>
+      val (safeTargetTypes, unsafeTargetTypes) = numericTypes.partition(to => isCastSafe(from, to))
+
+      safeTargetTypes.foreach { to =>
+        assert(Cast.canSafeCast(from, to), s"It should be possible to safely cast $from to $to")
+        assert(!Cast.mayTruncate(from, to), s"No truncation is expected when casting $from to $to")
+      }
+
+      unsafeTargetTypes.foreach { to =>
+        assert(!Cast.canSafeCast(from, to), s"It shouldn't be possible to safely cast $from to $to")
+        assert(Cast.mayTruncate(from, to), s"Truncation is expected when casting $from to $to")
+      }
+    }
+  }
 }

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 50406bc..01d0877 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1567,6 +1567,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val exceptDF = inputDF.filter(col("a").isin("0") or col("b") > "c")
     checkAnswer(inputDF.except(exceptDF), Seq(Row("1", null)))
   }
+
+  test("SPARK-26706: Fix Cast.mayTruncate for bytes") {
+    val thrownException = intercept[AnalysisException] {
+      spark.range(Long.MaxValue - 10, Long.MaxValue).as[Byte]
+        .map(b => b - 1)
+        .collect()
+    }
+    assert(thrownException.message.contains("Cannot up cast `id` from bigint to tinyint"))
+  }
 }
 
 case class TestDataUnion(x: Int, y: Int, z: Int)
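For the curious, the values in the example output in the commit message are exactly what two's-complement truncation predicts: the low byte of `Long.MaxValue - 10` is 0xF5 (-11 as a signed byte), the low byte of `Long.MaxValue - 1` is 0xFE (-2), and the `map` subtracts one more. A plain-Scala sketch (no Spark required; the object name is illustrative) reproduces the corrupted values:

```scala
object ByteTruncationSketch {
  def main(args: Array[String]): Unit = {
    // Same ids as spark.range(Long.MaxValue - 10, Long.MaxValue); upper bound is exclusive.
    val ids = (Long.MaxValue - 10) until Long.MaxValue
    // .toByte keeps only the low 8 bits -- the silent truncation the old cast allowed.
    val values = ids.map(id => id.toByte - 1)
    println(values.mkString(", ")) // -12, -11, -10, -9, -8, -7, -6, -5, -4, -3
  }
}
```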