This is an automated email from the ASF dual-hosted git repository.

gengliangwang pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.x by this push:
     new 8899f89ccc85 [SPARK-56911][SQL] Simplify Cast to decimal codegen under ANSI mode
8899f89ccc85 is described below

commit 8899f89ccc85ea2d28c85a2cd31185e940f37cdd
Author: Gengliang Wang <[email protected]>
AuthorDate: Wed May 27 21:53:32 2026 -0700

    [SPARK-56911][SQL] Simplify Cast to decimal codegen under ANSI mode
    
    ### What changes were proposed in this pull request?
    
    Extend `CastUtils.java` with two helpers for decimal precision adjustment and use them from `Cast.changePrecision` (both eval and codegen). The new helpers mutate the input `Decimal` in place (matching the in-place semantics of the existing inline codegen), so they're safe to call on the temporary produced by `Decimal.fromString(...)` / `Decimal.apply(...)` / decimal-arithmetic results.
    
    Helpers added:
    * `changePrecisionExact(Decimal, int, int, QueryContext)`: ANSI mode. Throws on overflow and takes the per-call-site `QueryContext`, so error messages keep their query-origin info.
    * `changePrecisionOrNull(Decimal, int, int)`: non-ANSI mode. Returns `null` on overflow (no `QueryContext` needed).
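    
    A minimal, self-contained Java sketch of the two helper contracts described above. `MutableDecimal` is a hypothetical stand-in for Spark's `Decimal` (built on `java.math.BigDecimal`, with half-up rounding assumed), and the `QueryContext` argument is omitted, so this illustrates the in-place return-or-throw and return-or-null behavior rather than the actual Spark code:
    
    ```java
    import java.math.BigDecimal;
    import java.math.RoundingMode;
    
    final class ChangePrecisionSketch {
    
      /** Hypothetical stand-in for Spark's Decimal: a value that is rescaled in place. */
      static final class MutableDecimal {
        private BigDecimal value;
    
        MutableDecimal(String s) {
          this.value = new BigDecimal(s);
        }
    
        /** Rounds to `scale`; commits and returns true only if the result fits `precision`. */
        boolean changePrecision(int precision, int scale) {
          BigDecimal rounded = value.setScale(scale, RoundingMode.HALF_UP);
          if (rounded.precision() > precision) {
            return false; // overflow: nothing is written back, the original value is kept
          }
          value = rounded;
          return true;
        }
    
        @Override
        public String toString() {
          return value.toPlainString();
        }
      }
    
      /** ANSI-style contract: return the same (mutated) instance, or throw on overflow. */
      static MutableDecimal changePrecisionExact(MutableDecimal d, int precision, int scale) {
        if (d.changePrecision(precision, scale)) return d;
        // `d` is unchanged here, so the message cites the original (pre-cast) value.
        throw new ArithmeticException(
            d + " cannot be represented as Decimal(" + precision + ", " + scale + ")");
      }
    
      /** Non-ANSI contract: return the same (mutated) instance, or null on overflow. */
      static MutableDecimal changePrecisionOrNull(MutableDecimal d, int precision, int scale) {
        return d.changePrecision(precision, scale) ? d : null;
      }
    }
    ```
    
    Because both helpers hand back the instance they were given, a caller should only pass a value it owns, such as a freshly created temporary.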
    
    `Cast.scala` changes:
    * `changePrecision` eval method dispatches on `nullOnOverflow` and delegates to the appropriate helper.
    * `changePrecision` codegen method has three branches now: the existing `canNullSafeCast` fast path (unchanged), a `nullOnOverflow` branch (inline), and the ANSI throw branch which now emits a one-line `CastUtils.changePrecisionExact(...)` call instead of the 5-line `if/else` overflow block.
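    
    The three codegen branches can be pictured with a small Java sketch that returns the source text for one cast site. The generated variable names (`d`, `evPrim`, `evNull`) and the `errorContext` placeholder are hypothetical; only the branch structure and the fully qualified `CastUtils` name follow the patch:
    
    ```java
    final class ChangePrecisionCodegenSketch {
    
      // Fully qualified helper class, taken from the package declaration in the patch.
      private static final String CAST_UTILS =
          "org.apache.spark.sql.catalyst.expressions.CastUtils";
    
      /** Returns illustrative Java source for one cast-to-decimal site. */
      static String emit(
          String d, String evPrim, String evNull,
          int precision, int scale,
          boolean canNullSafeCast, boolean nullOnOverflow) {
        String args = precision + ", " + scale;
        if (canNullSafeCast) {
          // Fast path: the cast cannot overflow, so no check is emitted.
          return d + ".changePrecision(" + args + ");\n"
              + evPrim + " = " + d + ";\n";
        } else if (nullOnOverflow) {
          // Non-ANSI: inline if/else that sets the null flag on overflow.
          return "if (" + d + ".changePrecision(" + args + ")) {\n"
              + "  " + evPrim + " = " + d + ";\n"
              + "} else {\n"
              + "  " + evNull + " = true;\n"
              + "}\n";
        } else {
          // ANSI: a single helper call; `errorContext` is a placeholder name.
          return evPrim + " = " + CAST_UTILS + ".changePrecisionExact("
              + d + ", " + args + ", errorContext);\n";
        }
      }
    }
    ```
    
    Only the last branch shrinks: the throw logic now lives once in the helper instead of being repeated in the source emitted for every cast site.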
    
    ### Why are the changes needed?
    
    Part of SPARK-56908 (umbrella). The ANSI throw branch of `Cast.changePrecision` is hit by every cast to decimal that may overflow (very common in TPC-DS, where `cast(int as decimal(7,2))` is widespread). Collapsing the 5-line inline body to one line shrinks the generated Java source for those plans.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No. The compiled behavior is identical; only the emitted Java source text changes.
    
    ### How was this patch tested?
    
    ```
    build/sbt "catalyst/testOnly *CastSuite *CastWithAnsiOnSuite *CastWithAnsiOffSuite *AnsiCastSuite *TryCastSuite *DecimalSuite"
    ```
    
    332/332 pass.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Generated-by: Cursor 1.x
    
    Closes #55936 from gengliangwang/SPARK-56911-cast-decimal.
    
    Authored-by: Gengliang Wang <[email protected]>
    Signed-off-by: Gengliang Wang <[email protected]>
    (cherry picked from commit 8af1b8b99d930eee53f552780cf6a15fafb7e343)
    Signed-off-by: Gengliang Wang <[email protected]>
---
 .../spark/sql/catalyst/expressions/CastUtils.java  | 33 ++++++++++++++++------
 .../spark/sql/catalyst/expressions/Cast.scala      | 32 +++++++++------------
 2 files changed, 38 insertions(+), 27 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java
index 700f7e41d233..a2e427b4a4ce 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java
@@ -17,19 +17,15 @@
 
 package org.apache.spark.sql.catalyst.expressions;
 
+import org.apache.spark.QueryContext;
 import org.apache.spark.sql.errors.QueryExecutionErrors;
 import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Decimal;
 
 /**
- * Static helpers used by {@code Cast.doGenCode} (and corresponding eval
- * paths) for ANSI overflow-checked narrowing to {@code byte} / {@code short}.
- *
- * <p>Narrowing to {@code int} / {@code long} is handled by calling the existing
- * {@code LongExactNumeric} / {@code FloatExactNumeric} / {@code DoubleExactNumeric}
- * Scala objects directly from codegen (see SPARK-56909). The helpers below
- * cover {@code byte} / {@code short} only, since {@code ByteExactNumeric} /
- * {@code ShortExactNumeric} don't expose a cross-type narrowing API.
+ * Static helpers used by {@code Cast.doGenCode} (and corresponding eval paths)
+ * for ANSI overflow-checked casts.
  *
  * <p>The source and target {@link DataType} objects referenced by the overflow
  * error message are held in {@code private static final} fields so the happy
@@ -47,6 +43,9 @@ public final class CastUtils {
   private static final DataType DOUBLE = DataTypes.DoubleType;
 
   // ----- integral narrowing (ANSI: throw on overflow) -----
+  // byte / short narrowing only; int / long narrowing is handled by calling the existing
+  // LongExactNumeric Scala object directly from codegen (see SPARK-56909). ByteExactNumeric /
+  // ShortExactNumeric don't expose a cross-type narrowing API, so a Java helper is the fit here.
 
   public static byte shortToByteExact(short v) {
     if (v == (byte) v) return (byte) v;
@@ -95,4 +94,22 @@ public final class CastUtils {
     if (Math.floor(v) <= Short.MAX_VALUE && Math.ceil(v) >= Short.MIN_VALUE) return (short) v;
     throw QueryExecutionErrors.castingCauseOverflowError(v, DOUBLE, SHORT);
   }
+
+  // ----- decimal precision adjustment -----
+  // Mutates the input Decimal in place to avoid the per-row clone() done by
+  // Decimal.toPrecision, since these helpers are called on the per-row hot path.
+  // On overflow, Decimal.changePrecision returns false before writing back any of
+  // decimalVal / longVal / _precision / _scale, so `d` is still in its original
+  // externally-visible state when changePrecisionExact throws -- the error message
+  // therefore cites the original (pre-cast) value.
+
+  public static Decimal changePrecisionExact(
+      Decimal d, int precision, int scale, QueryContext context) {
+    if (d.changePrecision(precision, scale)) return d;
+    throw QueryExecutionErrors.cannotChangeDecimalPrecisionError(d, precision, scale, context);
+  }
+
+  public static Decimal changePrecisionOrNull(Decimal d, int precision, int scale) {
+    return d.changePrecision(precision, scale) ? d : null;
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 0611c3e9bfb3..66501ebe7d5c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -1097,15 +1097,11 @@ case class Cast(
       value: Decimal,
       decimalType: DecimalType,
       nullOnOverflow: Boolean): Decimal = {
-    if (value.changePrecision(decimalType.precision, decimalType.scale)) {
-      value
+    if (nullOnOverflow) {
+      CastUtils.changePrecisionOrNull(value, decimalType.precision, decimalType.scale)
     } else {
-      if (nullOnOverflow) {
-        null
-      } else {
-        throw QueryExecutionErrors.cannotChangeDecimalPrecisionError(
-          value, decimalType.precision, decimalType.scale, getContextOrNull())
-      }
+      CastUtils.changePrecisionExact(
+        value, decimalType.precision, decimalType.scale, getContextOrNull())
     }
   }
 
@@ -1558,23 +1554,21 @@ case class Cast(
          |$d.changePrecision(${decimalType.precision}, ${decimalType.scale});
          |$evPrim = $d;
        """.stripMargin
-    } else {
-      val errorContextCode = getContextOrNullCode(ctx, !nullOnOverflow)
-      val overflowCode = if (nullOnOverflow) {
-        s"$evNull = true;"
-      } else {
-        s"""
-           |throw QueryExecutionErrors.cannotChangeDecimalPrecisionError(
-           |  $d, ${decimalType.precision}, ${decimalType.scale}, $errorContextCode);
-         """.stripMargin
-      }
+    } else if (nullOnOverflow) {
       code"""
         |if ($d.changePrecision(${decimalType.precision}, ${decimalType.scale})) {
          |  $evPrim = $d;
          |} else {
-         |  $overflowCode
+         |  $evNull = true;
          |}
        """.stripMargin
+    } else {
+      val errorContextCode = getContextOrNullCode(ctx)
+      val castUtils = classOf[CastUtils].getName
+      code"""
+         |$evPrim = $castUtils.changePrecisionExact(
+         |  $d, ${decimalType.precision}, ${decimalType.scale}, $errorContextCode);
+       """.stripMargin
     }
   }
 


