This is an automated email from the ASF dual-hosted git repository. yangjie01 pushed a commit to branch branch-4.0 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.0 by this push: new 8fac87e8cbe9 [SPARK-52945][SQL][TESTS] Split `CastSuiteBase#checkInvalidCastFromNumericType` into three methods and guarantee assertions are valid 8fac87e8cbe9 is described below commit 8fac87e8cbe952c4965bbcaf363bf0af10c4dc18 Author: yangjie01 <yangji...@baidu.com> AuthorDate: Mon Jul 28 20:20:59 2025 +0800 [SPARK-52945][SQL][TESTS] Split `CastSuiteBase#checkInvalidCastFromNumericType` into three methods and guarantee assertions are valid ### What changes were proposed in this pull request? Due to the absence of `assert` statements, the `CastSuiteBase#checkInvalidCastFromNumericType` method previously performed no assertion checks. Additionally, since `checkInvalidCastFromNumericType` had significant variations in target type validation logic across different `EvalMode` contexts, this PR refactors the method into three specialized methods to ensure robust assertion enforcement: - `checkInvalidCastFromNumericTypeToDateType` - `checkInvalidCastFromNumericTypeToTimestampNTZType` - `checkInvalidCastFromNumericTypeToBinaryType` ### Why are the changes needed? To address the missing assertion validation in `CastSuiteBase`. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? - Pass GitHub Actions ### Was this patch authored or co-authored using generative AI tooling? No Closes #51668 from LuciferYang/SPARK-52945. 
Authored-by: yangjie01 <yangji...@baidu.com> Signed-off-by: yangjie01 <yangji...@baidu.com> (cherry picked from commit 9a452f81dbddb765f55d0610e0e1691bd2ca6e96) Signed-off-by: yangjie01 <yangji...@baidu.com> --- .../sql/catalyst/expressions/CastSuiteBase.scala | 95 +++++++++------------- .../catalyst/expressions/CastWithAnsiOnSuite.scala | 32 +++++++- .../sql/catalyst/expressions/TryCastSuite.scala | 9 ++ 3 files changed, 77 insertions(+), 59 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala index cec49a5ae1de..869148165397 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala @@ -545,61 +545,42 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { checkCast("0", false) } - protected def checkInvalidCastFromNumericType(to: DataType): Unit = { - cast(1.toByte, to).checkInputDataTypes() == - DataTypeMismatch( - errorSubClass = "CAST_WITH_FUNC_SUGGESTION", - messageParameters = Map( - "srcType" -> toSQLType(Literal(1.toByte).dataType), - "targetType" -> toSQLType(to), - "functionNames" -> "`DATE_FROM_UNIX_DATE`" - ) - ) - cast(1.toShort, to).checkInputDataTypes() == - DataTypeMismatch( - errorSubClass = "CAST_WITH_FUNC_SUGGESTION", - messageParameters = Map( - "srcType" -> toSQLType(Literal(1.toShort).dataType), - "targetType" -> toSQLType(to), - "functionNames" -> "`DATE_FROM_UNIX_DATE`" - ) - ) - cast(1, to).checkInputDataTypes() == - DataTypeMismatch( - errorSubClass = "CAST_WITH_FUNC_SUGGESTION", - messageParameters = Map( - "srcType" -> toSQLType(Literal(1).dataType), - "targetType" -> toSQLType(to), - "functionNames" -> "`DATE_FROM_UNIX_DATE`" - ) - ) - cast(1L, to).checkInputDataTypes() == - DataTypeMismatch( - errorSubClass = 
"CAST_WITH_FUNC_SUGGESTION", - messageParameters = Map( - "srcType" -> toSQLType(Literal(1L).dataType), - "targetType" -> toSQLType(to), - "functionNames" -> "`DATE_FROM_UNIX_DATE`" - ) - ) - cast(1.0.toFloat, to).checkInputDataTypes() == - DataTypeMismatch( - errorSubClass = "CAST_WITH_FUNC_SUGGESTION", - messageParameters = Map( - "srcType" -> toSQLType(Literal(1.0.toFloat).dataType), - "targetType" -> toSQLType(to), - "functionNames" -> "`DATE_FROM_UNIX_DATE`" - ) - ) - cast(1.0, to).checkInputDataTypes() == - DataTypeMismatch( - errorSubClass = "CAST_WITH_FUNC_SUGGESTION", - messageParameters = Map( - "srcType" -> toSQLType(Literal(1.0).dataType), - "targetType" -> toSQLType(to), - "functionNames" -> "`DATE_FROM_UNIX_DATE`" - ) - ) + protected def createCastMismatch( + srcType: DataType, + targetType: DataType, + errorSubClass: String, + extraParams: Map[String, String] = Map.empty): DataTypeMismatch = { + val baseParams = Map( + "srcType" -> toSQLType(srcType), + "targetType" -> toSQLType(targetType) + ) + DataTypeMismatch(errorSubClass, baseParams ++ extraParams) + } + + protected def checkInvalidCastFromNumericTypeToDateType(): Unit = { + val errorSubClass = if (evalMode == EvalMode.LEGACY) { + "CAST_WITHOUT_SUGGESTION" + } else { + "CAST_WITH_FUNC_SUGGESTION" + } + val funcParams = if (evalMode == EvalMode.LEGACY) { + Map.empty[String, String] + } else { + Map("functionNames" -> "`DATE_FROM_UNIX_DATE`") + } + Seq(1.toByte, 1.toShort, 1, 1L, 1.0.toFloat, 1.0).foreach { testValue => + val expectedError = + createCastMismatch(Literal(testValue).dataType, DateType, errorSubClass, funcParams) + assert(cast(testValue, DateType).checkInputDataTypes() == expectedError) + } + } + protected def checkInvalidCastFromNumericTypeToTimestampNTZType(): Unit = { + // All numeric types: `CAST_WITHOUT_SUGGESTION` + Seq(1.toByte, 1.toShort, 1, 1L, 1.0.toFloat, 1.0).foreach { testValue => + val expectedError = + createCastMismatch(Literal(testValue).dataType, TimestampNTZType, 
"CAST_WITHOUT_SUGGESTION") + assert(cast(testValue, TimestampNTZType).checkInputDataTypes() == expectedError) + } } test("SPARK-16729 type checking for casting to date type") { @@ -614,7 +595,7 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { ) ) ) - checkInvalidCastFromNumericType(DateType) + checkInvalidCastFromNumericTypeToDateType() } test("SPARK-20302 cast with same structure") { @@ -998,7 +979,7 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { test("disallow type conversions between Numeric types and Timestamp without time zone type") { import DataTypeTestUtils.numericTypes - checkInvalidCastFromNumericType(TimestampNTZType) + checkInvalidCastFromNumericTypeToTimestampNTZType() verifyCastFailure( cast(Literal(0L), TimestampNTZType), DataTypeMismatch( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOnSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOnSuite.scala index f0709db7259d..5bb726e09afe 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOnSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOnSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants.MILLIS_PER_SECOND import org.apache.spark.sql.catalyst.util.DateTimeTestUtils import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, UTC} import org.apache.spark.sql.errors.QueryErrorsBase +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{UTF8String, VariantVal} @@ -39,6 +40,33 @@ class CastWithAnsiOnSuite extends CastSuiteBase with QueryErrorsBase { override def evalMode: EvalMode.Value = EvalMode.ANSI + protected def checkInvalidCastFromNumericTypeToBinaryType(): Unit = { + def checkNumericTypeCast( + testValue: Any, + srcType: 
DataType, + to: DataType, + expectedErrorClass: String, + extraParams: Map[String, String] = Map.empty): Unit = { + val expectedError = createCastMismatch(srcType, to, expectedErrorClass, extraParams) + assert(cast(testValue, to).checkInputDataTypes() == expectedError) + } + + // Integer types: suggest config change + val configParams = Map( + "config" -> toSQLConf(SQLConf.ANSI_ENABLED.key), + "configVal" -> toSQLValue("false", StringType) + ) + checkNumericTypeCast(1.toByte, ByteType, BinaryType, "CAST_WITH_CONF_SUGGESTION", configParams) + checkNumericTypeCast( + 1.toShort, ShortType, BinaryType, "CAST_WITH_CONF_SUGGESTION", configParams) + checkNumericTypeCast(1, IntegerType, BinaryType, "CAST_WITH_CONF_SUGGESTION", configParams) + checkNumericTypeCast(1L, LongType, BinaryType, "CAST_WITH_CONF_SUGGESTION", configParams) + + // Floating types: no suggestion + checkNumericTypeCast(1.0.toFloat, FloatType, BinaryType, "CAST_WITHOUT_SUGGESTION") + checkNumericTypeCast(1.0, DoubleType, BinaryType, "CAST_WITHOUT_SUGGESTION") + } + private def isTryCast = evalMode == EvalMode.TRY private def testIntMaxAndMin(dt: DataType): Unit = { @@ -142,7 +170,7 @@ class CastWithAnsiOnSuite extends CastSuiteBase with QueryErrorsBase { test("ANSI mode: disallow type conversions between Numeric types and Date type") { import DataTypeTestUtils.numericTypes - checkInvalidCastFromNumericType(DateType) + checkInvalidCastFromNumericTypeToDateType() verifyCastFailure( cast(Literal(0L), DateType), DataTypeMismatch( @@ -168,7 +196,7 @@ class CastWithAnsiOnSuite extends CastSuiteBase with QueryErrorsBase { test("ANSI mode: disallow type conversions between Numeric types and Binary type") { import DataTypeTestUtils.numericTypes - checkInvalidCastFromNumericType(BinaryType) + checkInvalidCastFromNumericTypeToBinaryType() val binaryLiteral = Literal(new Array[Byte](1.toByte), BinaryType) numericTypes.foreach { numericType => assert(cast(binaryLiteral, numericType).checkInputDataTypes() == diff 
--git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala index 446514de91d6..312b05755507 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala @@ -61,6 +61,15 @@ class TryCastSuite extends CastWithAnsiOnSuite { checkEvaluation(cast(l, to), tryCastResult, InternalRow(l.value)) } + override protected def checkInvalidCastFromNumericTypeToBinaryType(): Unit = { + // All numeric types: `CAST_WITHOUT_SUGGESTION` + Seq(1.toByte, 1.toShort, 1, 1L, 1.0.toFloat, 1.0).foreach { testValue => + val expectedError = + createCastMismatch(Literal(testValue).dataType, BinaryType, "CAST_WITHOUT_SUGGESTION") + assert(cast(testValue, BinaryType).checkInputDataTypes() == expectedError) + } + } + test("print string") { assert(cast(Literal("1"), IntegerType).toString == "try_cast(1 as int)") assert(cast(Literal("1"), IntegerType).sql == "TRY_CAST('1' AS INT)") --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org