This is an automated email from the ASF dual-hosted git repository.

yangjie01 pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new 8fac87e8cbe9 [SPARK-52945][SQL][TESTS] Split 
`CastSuiteBase#checkInvalidCastFromNumericType` into three methods and 
guarantee assertions are valid
8fac87e8cbe9 is described below

commit 8fac87e8cbe952c4965bbcaf363bf0af10c4dc18
Author: yangjie01 <yangji...@baidu.com>
AuthorDate: Mon Jul 28 20:20:59 2025 +0800

    [SPARK-52945][SQL][TESTS] Split 
`CastSuiteBase#checkInvalidCastFromNumericType` into three methods and 
guarantee assertions are valid
    
    ### What changes were proposed in this pull request?
    Due to the absence of `assert` statements, the 
`CastSuiteBase#checkInvalidCastFromNumericType` method previously performed no 
assertion checks.
    
    Additionally, since `checkInvalidCastFromNumericType` had significant 
variations in target-type validation logic across different `EvalMode` 
contexts, this PR refactors the method into three specialized methods to 
ensure robust assertion enforcement:
    
    - `checkInvalidCastFromNumericTypeToDateType`
    - `checkInvalidCastFromNumericTypeToTimestampNTZType`
    - `checkInvalidCastFromNumericTypeToBinaryType`
    
    ### Why are the changes needed?
    To address the missing assertion validation in `CastSuiteBase`.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    - Pass GitHub Actions
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #51668 from LuciferYang/SPARK-52945.
    
    Authored-by: yangjie01 <yangji...@baidu.com>
    Signed-off-by: yangjie01 <yangji...@baidu.com>
    (cherry picked from commit 9a452f81dbddb765f55d0610e0e1691bd2ca6e96)
    Signed-off-by: yangjie01 <yangji...@baidu.com>
---
 .../sql/catalyst/expressions/CastSuiteBase.scala   | 95 +++++++++-------------
 .../catalyst/expressions/CastWithAnsiOnSuite.scala | 32 +++++++-
 .../sql/catalyst/expressions/TryCastSuite.scala    |  9 ++
 3 files changed, 77 insertions(+), 59 deletions(-)

diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
index cec49a5ae1de..869148165397 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
@@ -545,61 +545,42 @@ abstract class CastSuiteBase extends SparkFunSuite with 
ExpressionEvalHelper {
     checkCast("0", false)
   }
 
-  protected def checkInvalidCastFromNumericType(to: DataType): Unit = {
-    cast(1.toByte, to).checkInputDataTypes() ==
-      DataTypeMismatch(
-        errorSubClass = "CAST_WITH_FUNC_SUGGESTION",
-        messageParameters = Map(
-          "srcType" -> toSQLType(Literal(1.toByte).dataType),
-          "targetType" -> toSQLType(to),
-          "functionNames" -> "`DATE_FROM_UNIX_DATE`"
-        )
-      )
-    cast(1.toShort, to).checkInputDataTypes() ==
-      DataTypeMismatch(
-        errorSubClass = "CAST_WITH_FUNC_SUGGESTION",
-        messageParameters = Map(
-          "srcType" -> toSQLType(Literal(1.toShort).dataType),
-          "targetType" -> toSQLType(to),
-          "functionNames" -> "`DATE_FROM_UNIX_DATE`"
-        )
-      )
-    cast(1, to).checkInputDataTypes() ==
-      DataTypeMismatch(
-        errorSubClass = "CAST_WITH_FUNC_SUGGESTION",
-        messageParameters = Map(
-          "srcType" -> toSQLType(Literal(1).dataType),
-          "targetType" -> toSQLType(to),
-          "functionNames" -> "`DATE_FROM_UNIX_DATE`"
-        )
-      )
-    cast(1L, to).checkInputDataTypes() ==
-      DataTypeMismatch(
-        errorSubClass = "CAST_WITH_FUNC_SUGGESTION",
-        messageParameters = Map(
-          "srcType" -> toSQLType(Literal(1L).dataType),
-          "targetType" -> toSQLType(to),
-          "functionNames" -> "`DATE_FROM_UNIX_DATE`"
-        )
-      )
-    cast(1.0.toFloat, to).checkInputDataTypes() ==
-      DataTypeMismatch(
-        errorSubClass = "CAST_WITH_FUNC_SUGGESTION",
-        messageParameters = Map(
-          "srcType" -> toSQLType(Literal(1.0.toFloat).dataType),
-          "targetType" -> toSQLType(to),
-          "functionNames" -> "`DATE_FROM_UNIX_DATE`"
-        )
-      )
-    cast(1.0, to).checkInputDataTypes() ==
-      DataTypeMismatch(
-        errorSubClass = "CAST_WITH_FUNC_SUGGESTION",
-        messageParameters = Map(
-          "srcType" -> toSQLType(Literal(1.0).dataType),
-          "targetType" -> toSQLType(to),
-          "functionNames" -> "`DATE_FROM_UNIX_DATE`"
-        )
-      )
+  protected def createCastMismatch(
+      srcType: DataType,
+      targetType: DataType,
+      errorSubClass: String,
+      extraParams: Map[String, String] = Map.empty): DataTypeMismatch = {
+    val baseParams = Map(
+      "srcType" -> toSQLType(srcType),
+      "targetType" -> toSQLType(targetType)
+    )
+    DataTypeMismatch(errorSubClass, baseParams ++ extraParams)
+  }
+
+  protected def checkInvalidCastFromNumericTypeToDateType(): Unit = {
+    val errorSubClass = if (evalMode == EvalMode.LEGACY) {
+      "CAST_WITHOUT_SUGGESTION"
+    } else {
+      "CAST_WITH_FUNC_SUGGESTION"
+    }
+    val funcParams = if (evalMode == EvalMode.LEGACY) {
+      Map.empty[String, String]
+    } else {
+      Map("functionNames" -> "`DATE_FROM_UNIX_DATE`")
+    }
+    Seq(1.toByte, 1.toShort, 1, 1L, 1.0.toFloat, 1.0).foreach { testValue =>
+      val expectedError =
+        createCastMismatch(Literal(testValue).dataType, DateType, 
errorSubClass, funcParams)
+      assert(cast(testValue, DateType).checkInputDataTypes() == expectedError)
+    }
+  }
+  protected def checkInvalidCastFromNumericTypeToTimestampNTZType(): Unit = {
+    // All numeric types: `CAST_WITHOUT_SUGGESTION`
+    Seq(1.toByte, 1.toShort, 1, 1L, 1.0.toFloat, 1.0).foreach { testValue =>
+      val expectedError =
+        createCastMismatch(Literal(testValue).dataType, TimestampNTZType, 
"CAST_WITHOUT_SUGGESTION")
+      assert(cast(testValue, TimestampNTZType).checkInputDataTypes() == 
expectedError)
+    }
   }
 
   test("SPARK-16729 type checking for casting to date type") {
@@ -614,7 +595,7 @@ abstract class CastSuiteBase extends SparkFunSuite with 
ExpressionEvalHelper {
         )
       )
     )
-    checkInvalidCastFromNumericType(DateType)
+    checkInvalidCastFromNumericTypeToDateType()
   }
 
   test("SPARK-20302 cast with same structure") {
@@ -998,7 +979,7 @@ abstract class CastSuiteBase extends SparkFunSuite with 
ExpressionEvalHelper {
 
   test("disallow type conversions between Numeric types and Timestamp without 
time zone type") {
     import DataTypeTestUtils.numericTypes
-    checkInvalidCastFromNumericType(TimestampNTZType)
+    checkInvalidCastFromNumericTypeToTimestampNTZType()
     verifyCastFailure(
       cast(Literal(0L), TimestampNTZType),
       DataTypeMismatch(
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOnSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOnSuite.scala
index f0709db7259d..5bb726e09afe 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOnSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOnSuite.scala
@@ -29,6 +29,7 @@ import 
org.apache.spark.sql.catalyst.util.DateTimeConstants.MILLIS_PER_SECOND
 import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
 import 
org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, UTC}
 import org.apache.spark.sql.errors.QueryErrorsBase
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
 
@@ -39,6 +40,33 @@ class CastWithAnsiOnSuite extends CastSuiteBase with 
QueryErrorsBase {
 
   override def evalMode: EvalMode.Value = EvalMode.ANSI
 
+  protected def checkInvalidCastFromNumericTypeToBinaryType(): Unit = {
+    def checkNumericTypeCast(
+        testValue: Any,
+        srcType: DataType,
+        to: DataType,
+        expectedErrorClass: String,
+        extraParams: Map[String, String] = Map.empty): Unit = {
+      val expectedError = createCastMismatch(srcType, to, expectedErrorClass, 
extraParams)
+      assert(cast(testValue, to).checkInputDataTypes() == expectedError)
+    }
+
+    // Integer types: suggest config change
+    val configParams = Map(
+      "config" -> toSQLConf(SQLConf.ANSI_ENABLED.key),
+      "configVal" -> toSQLValue("false", StringType)
+    )
+    checkNumericTypeCast(1.toByte, ByteType, BinaryType, 
"CAST_WITH_CONF_SUGGESTION", configParams)
+    checkNumericTypeCast(
+      1.toShort, ShortType, BinaryType, "CAST_WITH_CONF_SUGGESTION", 
configParams)
+    checkNumericTypeCast(1, IntegerType, BinaryType, 
"CAST_WITH_CONF_SUGGESTION", configParams)
+    checkNumericTypeCast(1L, LongType, BinaryType, 
"CAST_WITH_CONF_SUGGESTION", configParams)
+
+    // Floating types: no suggestion
+    checkNumericTypeCast(1.0.toFloat, FloatType, BinaryType, 
"CAST_WITHOUT_SUGGESTION")
+    checkNumericTypeCast(1.0, DoubleType, BinaryType, 
"CAST_WITHOUT_SUGGESTION")
+  }
+
   private def isTryCast = evalMode == EvalMode.TRY
 
   private def testIntMaxAndMin(dt: DataType): Unit = {
@@ -142,7 +170,7 @@ class CastWithAnsiOnSuite extends CastSuiteBase with 
QueryErrorsBase {
 
   test("ANSI mode: disallow type conversions between Numeric types and Date 
type") {
     import DataTypeTestUtils.numericTypes
-    checkInvalidCastFromNumericType(DateType)
+    checkInvalidCastFromNumericTypeToDateType()
     verifyCastFailure(
       cast(Literal(0L), DateType),
       DataTypeMismatch(
@@ -168,7 +196,7 @@ class CastWithAnsiOnSuite extends CastSuiteBase with 
QueryErrorsBase {
 
   test("ANSI mode: disallow type conversions between Numeric types and Binary 
type") {
     import DataTypeTestUtils.numericTypes
-    checkInvalidCastFromNumericType(BinaryType)
+    checkInvalidCastFromNumericTypeToBinaryType()
     val binaryLiteral = Literal(new Array[Byte](1.toByte), BinaryType)
     numericTypes.foreach { numericType =>
       assert(cast(binaryLiteral, numericType).checkInputDataTypes() ==
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala
index 446514de91d6..312b05755507 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala
@@ -61,6 +61,15 @@ class TryCastSuite extends CastWithAnsiOnSuite {
     checkEvaluation(cast(l, to), tryCastResult, InternalRow(l.value))
   }
 
+  override protected def checkInvalidCastFromNumericTypeToBinaryType(): Unit = 
{
+    // All numeric types: `CAST_WITHOUT_SUGGESTION`
+    Seq(1.toByte, 1.toShort, 1, 1L, 1.0.toFloat, 1.0).foreach { testValue =>
+      val expectedError =
+        createCastMismatch(Literal(testValue).dataType, BinaryType, 
"CAST_WITHOUT_SUGGESTION")
+      assert(cast(testValue, BinaryType).checkInputDataTypes() == 
expectedError)
+    }
+  }
+
   test("print string") {
     assert(cast(Literal("1"), IntegerType).toString == "try_cast(1 as int)")
     assert(cast(Literal("1"), IntegerType).sql == "TRY_CAST('1' AS INT)")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to