This is an automated email from the ASF dual-hosted git repository.
gengliangwang pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.x by this push:
new 458b67c2c093 [SPARK-57170][SQL] Simplify Cast string to float/double
codegen under ANSI mode
458b67c2c093 is described below
commit 458b67c2c0936bbea0176ab68d6544b82621deee
Author: Gengliang Wang <[email protected]>
AuthorDate: Mon Jun 1 09:45:19 2026 -0700
[SPARK-57170][SQL] Simplify Cast string to float/double codegen under ANSI
mode
### What changes were proposed in this pull request?
Add `CastUtils.stringToFloatExact(UTF8String, QueryContext)` and
`CastUtils.stringToDoubleExact(UTF8String, QueryContext)`, and route the ANSI
(`ansiEnabled = true`) string -> float/double eval and codegen paths through
them. Each helper parses the string and, on a `NumberFormatException`, falls
back to `Cast.processFloatingPointSpecialLiterals` (inf / +inf / -inf /
infinity / nan, case-insensitive); if that also yields no value it throws the
ANSI `CAST_INVALID_INPUT` error citing the original (untrimmed) input string.
`castToFloatCode` / `castToDoubleCode` previously emitted the same ~10-line
`try { Float.valueOf(...) } catch (NumberFormatException) { ...special-literal
fallback... }` block at both string call sites. They now dispatch on
`ansiEnabled`: the ANSI branch emits a single helper call, while the non-ANSI
branch keeps the inline `try/catch -> isNull` form (matching the pattern used
by the already-merged Cast-to-decimal / `MakeDate` / `MakeInterval` cleanups).
The eval paths delegate to the same helpers when `ansiEnabled` is true.
### Why are the changes needed?
Part of SPARK-56908 (umbrella). Collapsing the duplicated inline parse +
special-literal fallback to a single helper call shrinks the generated Java for
the common `CAST(string AS FLOAT/DOUBLE)`, helping with the JVM 64KB method /
constant-pool limits, Janino compile time, and JIT work.
### Does this PR introduce _any_ user-facing change?
No. Runtime behavior is identical; only the emitted Java source text and the
internal structure of the eval paths change.
### How was this patch tested?
```
build/sbt "catalyst/testOnly *CastSuite *CastWithAnsiOnSuite
*CastWithAnsiOffSuite *TryCastSuite"
```
All pass (411 tests; exercised both with and without whole-stage codegen).
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Claude Code (Opus 4.8)
Closes #56220 from gengliangwang/spark-cast-string-float-codegen.
Authored-by: Gengliang Wang <[email protected]>
Signed-off-by: Gengliang Wang <[email protected]>
(cherry picked from commit f9e75b2c5f0dc58a7b52d50481c88fd0051e7a4d)
Signed-off-by: Gengliang Wang <[email protected]>
---
.../spark/sql/catalyst/expressions/CastUtils.java | 34 +++++++
.../spark/sql/catalyst/expressions/Cast.scala | 102 ++++++++++-----------
2 files changed, 82 insertions(+), 54 deletions(-)
diff --git
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java
index a2e427b4a4ce..424a52e7d638 100644
---
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java
+++
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java
@@ -22,6 +22,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.unsafe.types.UTF8String;
/**
* Static helpers used by {@code Cast.doGenCode} (and corresponding eval paths)
@@ -112,4 +113,37 @@ public final class CastUtils {
public static Decimal changePrecisionOrNull(Decimal d, int precision, int
scale) {
return d.changePrecision(precision, scale) ? d : null;
}
+
+ // ----- string -> floating point (ANSI: throw on invalid input) -----
+ // Mirrors castToFloatCode / castToDoubleCode: parse the string, and on a
+ // NumberFormatException fall back to the special-literal forms handled by
+ // Cast.processFloatingPointSpecialLiterals (inf / +inf / -inf / infinity /
nan,
+ // case-insensitive). If that also yields no value, throw the ANSI
+ // CAST_INVALID_INPUT error citing the original (untrimmed) input string.
+
+ public static float stringToFloatExact(UTF8String s, QueryContext context) {
+ String str = s.toString();
+ try {
+ return Float.parseFloat(str);
+ } catch (NumberFormatException e) {
+ Float f = (Float) Cast.processFloatingPointSpecialLiterals(str, true);
+ if (f == null) {
+ throw QueryExecutionErrors.invalidInputInCastToNumberError(FLOAT, s,
context);
+ }
+ return f;
+ }
+ }
+
+ public static double stringToDoubleExact(UTF8String s, QueryContext context)
{
+ String str = s.toString();
+ try {
+ return Double.parseDouble(str);
+ } catch (NumberFormatException e) {
+ Double d = (Double) Cast.processFloatingPointSpecialLiterals(str, false);
+ if (d == null) {
+ throw QueryExecutionErrors.invalidInputInCastToNumberError(DOUBLE, s,
context);
+ }
+ return d;
+ }
+ }
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index f190f8ca5055..8dcdb3b81128 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -1168,16 +1168,14 @@ case class Cast(
private[this] def castToDouble(from: DataType): Any => Any = from match {
case _: StringType =>
buildCast[UTF8String](_, s => {
- val doubleStr = s.toString
- try doubleStr.toDouble catch {
- case _: NumberFormatException =>
- val d = Cast.processFloatingPointSpecialLiterals(doubleStr, false)
- if (ansiEnabled && d == null) {
- throw QueryExecutionErrors.invalidInputInCastToNumberError(
- DoubleType, s, getContextOrNull())
- } else {
- d
- }
+ if (ansiEnabled) {
+ CastUtils.stringToDoubleExact(s, getContextOrNull())
+ } else {
+ val doubleStr = s.toString
+ try doubleStr.toDouble catch {
+ case _: NumberFormatException =>
+ Cast.processFloatingPointSpecialLiterals(doubleStr, false)
+ }
}
})
case BooleanType =>
@@ -1195,16 +1193,14 @@ case class Cast(
private[this] def castToFloat(from: DataType): Any => Any = from match {
case _: StringType =>
buildCast[UTF8String](_, s => {
- val floatStr = s.toString
- try floatStr.toFloat catch {
- case _: NumberFormatException =>
- val f = Cast.processFloatingPointSpecialLiterals(floatStr, true)
- if (ansiEnabled && f == null) {
- throw QueryExecutionErrors.invalidInputInCastToNumberError(
- FloatType, s, getContextOrNull())
- } else {
- f
- }
+ if (ansiEnabled) {
+ CastUtils.stringToFloatExact(s, getContextOrNull())
+ } else {
+ val floatStr = s.toString
+ try floatStr.toFloat catch {
+ case _: NumberFormatException =>
+ Cast.processFloatingPointSpecialLiterals(floatStr, true)
+ }
}
})
case BooleanType =>
@@ -2208,28 +2204,27 @@ case class Cast(
private[this] def castToFloatCode(from: DataType, ctx: CodegenContext):
CastFunction = {
from match {
case _: StringType =>
- val floatStr = ctx.freshVariable("floatStr", StringType)
(c, evPrim, evNull) =>
- val handleNull = if (ansiEnabled) {
+ if (ansiEnabled) {
+ val castUtils = classOf[CastUtils].getName
val errorContext = getContextOrNullCode(ctx)
- "throw QueryExecutionErrors.invalidInputInCastToNumberError(" +
- s"org.apache.spark.sql.types.FloatType$$.MODULE$$,$c,
$errorContext);"
+ code"$evPrim = $castUtils.stringToFloatExact($c, $errorContext);"
} else {
- s"$evNull = true;"
- }
- code"""
- final String $floatStr = $c.toString();
- try {
- $evPrim = Float.valueOf($floatStr);
- } catch (java.lang.NumberFormatException e) {
- final Float f = (Float)
Cast.processFloatingPointSpecialLiterals($floatStr, true);
- if (f == null) {
- $handleNull
- } else {
- $evPrim = f.floatValue();
+ val floatStr = ctx.freshVariable("floatStr", StringType)
+ code"""
+ final String $floatStr = $c.toString();
+ try {
+ $evPrim = Float.valueOf($floatStr);
+ } catch (java.lang.NumberFormatException e) {
+ final Float f = (Float)
Cast.processFloatingPointSpecialLiterals($floatStr, true);
+ if (f == null) {
+ $evNull = true;
+ } else {
+ $evPrim = f.floatValue();
+ }
}
+ """
}
- """
case BooleanType =>
(c, evPrim, evNull) => code"$evPrim = $c ? 1.0f : 0.0f;"
case DateType =>
@@ -2246,28 +2241,27 @@ case class Cast(
private[this] def castToDoubleCode(from: DataType, ctx: CodegenContext):
CastFunction = {
from match {
case _: StringType =>
- val doubleStr = ctx.freshVariable("doubleStr", StringType)
(c, evPrim, evNull) =>
- val handleNull = if (ansiEnabled) {
+ if (ansiEnabled) {
+ val castUtils = classOf[CastUtils].getName
val errorContext = getContextOrNullCode(ctx)
- "throw QueryExecutionErrors.invalidInputInCastToNumberError(" +
- s"org.apache.spark.sql.types.DoubleType$$.MODULE$$, $c,
$errorContext);"
+ code"$evPrim = $castUtils.stringToDoubleExact($c, $errorContext);"
} else {
- s"$evNull = true;"
- }
- code"""
- final String $doubleStr = $c.toString();
- try {
- $evPrim = Double.valueOf($doubleStr);
- } catch (java.lang.NumberFormatException e) {
- final Double d = (Double)
Cast.processFloatingPointSpecialLiterals($doubleStr, false);
- if (d == null) {
- $handleNull
- } else {
- $evPrim = d.doubleValue();
+ val doubleStr = ctx.freshVariable("doubleStr", StringType)
+ code"""
+ final String $doubleStr = $c.toString();
+ try {
+ $evPrim = Double.valueOf($doubleStr);
+ } catch (java.lang.NumberFormatException e) {
+ final Double d = (Double)
Cast.processFloatingPointSpecialLiterals($doubleStr, false);
+ if (d == null) {
+ $evNull = true;
+ } else {
+ $evPrim = d.doubleValue();
+ }
}
+ """
}
- """
case BooleanType =>
(c, evPrim, evNull) => code"$evPrim = $c ? 1.0d : 0.0d;"
case DateType =>
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]