This is an automated email from the ASF dual-hosted git repository.

gengliangwang pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.x by this push:
     new 458b67c2c093 [SPARK-57170][SQL] Simplify Cast string to float/double codegen under ANSI mode
458b67c2c093 is described below

commit 458b67c2c0936bbea0176ab68d6544b82621deee
Author: Gengliang Wang <[email protected]>
AuthorDate: Mon Jun 1 09:45:19 2026 -0700

    [SPARK-57170][SQL] Simplify Cast string to float/double codegen under ANSI mode
    
    ### What changes were proposed in this pull request?
    
    Add `CastUtils.stringToFloatExact(UTF8String, QueryContext)` and
    `CastUtils.stringToDoubleExact(UTF8String, QueryContext)`, and route the ANSI
    (`ansiEnabled = true`) string -> float/double eval and codegen paths through
    them. Each helper parses the string and, on a `NumberFormatException`, falls
    back to `Cast.processFloatingPointSpecialLiterals` (inf / +inf / -inf /
    infinity / nan, case-insensitive); if that also yields no value it throws the
    ANSI `CAST_INVALID_INPUT` error citing the [...]
    
    `castToFloatCode` / `castToDoubleCode` previously each emitted the same ~10-line
    `try { Float.valueOf(...) } catch (NumberFormatException) { ...special-literal
    fallback... }` block at their string call sites. They now dispatch on
    `ansiEnabled`: the ANSI branch emits a single helper call, while the non-ANSI
    branch keeps the inline `try/catch -> isNull` form (matching the pattern used
    by the already-merged Cast-to-decimal / `MakeDate` / `MakeInterval` cleanups).
    The eval paths delegate to the [...]
    
    ### Why are the changes needed?
    
    Part of SPARK-56908 (umbrella). Collapsing the duplicated inline parse +
    special-literal fallback into a single helper call shrinks the generated Java
    for the common ANSI-mode `CAST(string AS FLOAT/DOUBLE)`, which eases pressure
    on the JVM 64KB method / constant-pool limits and reduces Janino compile time
    and JIT work.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No. Runtime behavior is unchanged on both the interpreted (eval) and codegen
    paths; only the code structure and the emitted Java source text change.
    
    ### How was this patch tested?
    
    ```
    build/sbt "catalyst/testOnly *CastSuite *CastWithAnsiOnSuite *CastWithAnsiOffSuite *TryCastSuite"
    ```
    
    All pass (411 tests; exercised both with and without whole-stage codegen).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Generated-by: Claude Code (Opus 4.8)
    
    Closes #56220 from gengliangwang/spark-cast-string-float-codegen.
    
    Authored-by: Gengliang Wang <[email protected]>
    Signed-off-by: Gengliang Wang <[email protected]>
    (cherry picked from commit f9e75b2c5f0dc58a7b52d50481c88fd0051e7a4d)
    Signed-off-by: Gengliang Wang <[email protected]>
---
 .../spark/sql/catalyst/expressions/CastUtils.java  |  34 +++++++
 .../spark/sql/catalyst/expressions/Cast.scala      | 102 ++++++++++-----------
 2 files changed, 82 insertions(+), 54 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java
index a2e427b4a4ce..424a52e7d638 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java
@@ -22,6 +22,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors;
 import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.unsafe.types.UTF8String;
 
 /**
  * Static helpers used by {@code Cast.doGenCode} (and corresponding eval paths)
@@ -112,4 +113,37 @@ public final class CastUtils {
   public static Decimal changePrecisionOrNull(Decimal d, int precision, int scale) {
     return d.changePrecision(precision, scale) ? d : null;
   }
+
+  // ----- string -> floating point (ANSI: throw on invalid input) -----
+  // Mirrors castToFloatCode / castToDoubleCode: parse the string, and on a
+  // NumberFormatException fall back to the special-literal forms handled by
+  // Cast.processFloatingPointSpecialLiterals (inf / +inf / -inf / infinity / nan,
+  // case-insensitive). If that also yields no value, throw the ANSI
+  // CAST_INVALID_INPUT error citing the original (untrimmed) input string.
+
+  public static float stringToFloatExact(UTF8String s, QueryContext context) {
+    String str = s.toString();
+    try {
+      return Float.parseFloat(str);
+    } catch (NumberFormatException e) {
+      Float f = (Float) Cast.processFloatingPointSpecialLiterals(str, true);
+      if (f == null) {
+        throw QueryExecutionErrors.invalidInputInCastToNumberError(FLOAT, s, context);
+      }
+      return f;
+    }
+  }
+
+  public static double stringToDoubleExact(UTF8String s, QueryContext context) {
+    String str = s.toString();
+    try {
+      return Double.parseDouble(str);
+    } catch (NumberFormatException e) {
+      Double d = (Double) Cast.processFloatingPointSpecialLiterals(str, false);
+      if (d == null) {
+        throw QueryExecutionErrors.invalidInputInCastToNumberError(DOUBLE, s, context);
+      }
+      return d;
+    }
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index f190f8ca5055..8dcdb3b81128 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -1168,16 +1168,14 @@ case class Cast(
   private[this] def castToDouble(from: DataType): Any => Any = from match {
     case _: StringType =>
       buildCast[UTF8String](_, s => {
-        val doubleStr = s.toString
-        try doubleStr.toDouble catch {
-          case _: NumberFormatException =>
-            val d = Cast.processFloatingPointSpecialLiterals(doubleStr, false)
-            if (ansiEnabled && d == null) {
-              throw QueryExecutionErrors.invalidInputInCastToNumberError(
-                DoubleType, s, getContextOrNull())
-            } else {
-              d
-            }
+        if (ansiEnabled) {
+          CastUtils.stringToDoubleExact(s, getContextOrNull())
+        } else {
+          val doubleStr = s.toString
+          try doubleStr.toDouble catch {
+            case _: NumberFormatException =>
+              Cast.processFloatingPointSpecialLiterals(doubleStr, false)
+          }
         }
       })
     case BooleanType =>
@@ -1195,16 +1193,14 @@ case class Cast(
   private[this] def castToFloat(from: DataType): Any => Any = from match {
     case _: StringType =>
       buildCast[UTF8String](_, s => {
-        val floatStr = s.toString
-        try floatStr.toFloat catch {
-          case _: NumberFormatException =>
-            val f = Cast.processFloatingPointSpecialLiterals(floatStr, true)
-            if (ansiEnabled && f == null) {
-              throw QueryExecutionErrors.invalidInputInCastToNumberError(
-                FloatType, s, getContextOrNull())
-            } else {
-              f
-            }
+        if (ansiEnabled) {
+          CastUtils.stringToFloatExact(s, getContextOrNull())
+        } else {
+          val floatStr = s.toString
+          try floatStr.toFloat catch {
+            case _: NumberFormatException =>
+              Cast.processFloatingPointSpecialLiterals(floatStr, true)
+          }
         }
       })
     case BooleanType =>
@@ -2208,28 +2204,27 @@ case class Cast(
   private[this] def castToFloatCode(from: DataType, ctx: CodegenContext): CastFunction = {
     from match {
       case _: StringType =>
-        val floatStr = ctx.freshVariable("floatStr", StringType)
         (c, evPrim, evNull) =>
-          val handleNull = if (ansiEnabled) {
+          if (ansiEnabled) {
+            val castUtils = classOf[CastUtils].getName
             val errorContext = getContextOrNullCode(ctx)
-            "throw QueryExecutionErrors.invalidInputInCastToNumberError(" +
-              s"org.apache.spark.sql.types.FloatType$$.MODULE$$,$c, $errorContext);"
+            code"$evPrim = $castUtils.stringToFloatExact($c, $errorContext);"
           } else {
-            s"$evNull = true;"
-          }
-          code"""
-          final String $floatStr = $c.toString();
-          try {
-            $evPrim = Float.valueOf($floatStr);
-          } catch (java.lang.NumberFormatException e) {
-            final Float f = (Float) Cast.processFloatingPointSpecialLiterals($floatStr, true);
-            if (f == null) {
-              $handleNull
-            } else {
-              $evPrim = f.floatValue();
+            val floatStr = ctx.freshVariable("floatStr", StringType)
+            code"""
+            final String $floatStr = $c.toString();
+            try {
+              $evPrim = Float.valueOf($floatStr);
+            } catch (java.lang.NumberFormatException e) {
+              final Float f = (Float) Cast.processFloatingPointSpecialLiterals($floatStr, true);
+              if (f == null) {
+                $evNull = true;
+              } else {
+                $evPrim = f.floatValue();
+              }
             }
+          """
           }
-        """
       case BooleanType =>
         (c, evPrim, evNull) => code"$evPrim = $c ? 1.0f : 0.0f;"
       case DateType =>
@@ -2246,28 +2241,27 @@ case class Cast(
   private[this] def castToDoubleCode(from: DataType, ctx: CodegenContext): CastFunction = {
     from match {
       case _: StringType =>
-        val doubleStr = ctx.freshVariable("doubleStr", StringType)
         (c, evPrim, evNull) =>
-          val handleNull = if (ansiEnabled) {
+          if (ansiEnabled) {
+            val castUtils = classOf[CastUtils].getName
             val errorContext = getContextOrNullCode(ctx)
-            "throw QueryExecutionErrors.invalidInputInCastToNumberError(" +
-              s"org.apache.spark.sql.types.DoubleType$$.MODULE$$, $c, $errorContext);"
+            code"$evPrim = $castUtils.stringToDoubleExact($c, $errorContext);"
           } else {
-            s"$evNull = true;"
-          }
-          code"""
-          final String $doubleStr = $c.toString();
-          try {
-            $evPrim = Double.valueOf($doubleStr);
-          } catch (java.lang.NumberFormatException e) {
-            final Double d = (Double) Cast.processFloatingPointSpecialLiterals($doubleStr, false);
-            if (d == null) {
-              $handleNull
-            } else {
-              $evPrim = d.doubleValue();
+            val doubleStr = ctx.freshVariable("doubleStr", StringType)
+            code"""
+            final String $doubleStr = $c.toString();
+            try {
+              $evPrim = Double.valueOf($doubleStr);
+            } catch (java.lang.NumberFormatException e) {
+              final Double d = (Double) Cast.processFloatingPointSpecialLiterals($doubleStr, false);
+              if (d == null) {
+                $evNull = true;
+              } else {
+                $evPrim = d.doubleValue();
+              }
             }
+          """
           }
-        """
       case BooleanType =>
         (c, evPrim, evNull) => code"$evPrim = $c ? 1.0d : 0.0d;"
       case DateType =>


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to