This is an automated email from the ASF dual-hosted git repository.

gengliangwang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6165bb0c99f2 [SPARK-56910][SQL] Simplify Cast to byte/short codegen under ANSI mode
6165bb0c99f2 is described below

commit 6165bb0c99f219bc49586548abe61e683857f546
Author: Gengliang Wang <[email protected]>
AuthorDate: Wed May 27 11:30:45 2026 -0700

    [SPARK-56910][SQL] Simplify Cast to byte/short codegen under ANSI mode
    
    ### What changes were proposed in this pull request?
    
    Introduce `CastUtils.java` with nine static helpers for ANSI overflow-checked narrowing to `byte` / `short`, and use them from `Cast.scala` (both codegen and eval paths).
    
    Helpers added:
    * `shortToByteExact(short)`, `intToByteExact(int)`, `longToByteExact(long)`
    * `intToShortExact(int)`, `longToShortExact(long)`
    * `floatToByteExact(float)`, `doubleToByteExact(double)`
    * `floatToShortExact(float)`, `doubleToShortExact(double)`
    
    `ByteExactNumeric` / `ShortExactNumeric` only expose same-type identity narrowing (their `toByte(byte)` / `toShort(short)` are trivial). Unlike the `int` / `long` targets refactored in #55934, which delegate to `LongExactNumeric.toInt` / `FloatExactNumeric.toInt` / `DoubleExactNumeric.toInt` / `toLong`, there is no existing Scala object to route the byte/short narrowing through, so a Java helper is the cleanest fit.
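    
    The integral helpers rely on a round-trip check: a value fits the narrower type exactly when casting it down and widening it back yields the same value. A minimal standalone sketch of that idea follows. The class name `NarrowingSketch` and the `overflow` helper are hypothetical, and `ArithmeticException` is only a stand-in for the error built by `QueryExecutionErrors.castingCauseOverflowError`:
    
    ```java
    public final class NarrowingSketch {
      private NarrowingSketch() {}
    
      // Hypothetical stand-in for QueryExecutionErrors.castingCauseOverflowError.
      private static ArithmeticException overflow(Object v, String from, String to) {
        return new ArithmeticException("cannot cast " + v + " from " + from + " to " + to);
      }
    
      // The value fits iff narrowing and widening back returns the original value.
      public static byte intToByteExact(int v) {
        if (v == (byte) v) return (byte) v;
        throw overflow(v, "int", "byte");
      }
    
      // Same check for long -> short; (short) v is widened back to long for the comparison.
      public static short longToShortExact(long v) {
        if (v == (short) v) return (short) v;
        throw overflow(v, "long", "short");
      }
    
      public static void main(String[] args) {
        System.out.println(intToByteExact(127));        // in range
        System.out.println(longToShortExact(-32768L));  // in range
        try {
          intToByteExact(128);                          // out of range, takes the throw path
        } catch (ArithmeticException e) {
          System.out.println("overflow: " + e.getMessage());
        }
      }
    }
    ```
    
    The in-range path costs a single comparison and allocates nothing; the error is only built on overflow.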
    
    `Cast.scala` changes:
    * `castIntegralTypeToIntegralTypeExactCode`: the `byte` / `short` branch (previously an inline 5-line if/throw block) emits a single `CastUtils.${integralPrefix(from)}To${target.capitalize}Exact($c)` call. The `int` branch (added in #55934) is unchanged.
    * `castFractionToIntegralTypeCode`: the `byte` / `short` branch (previously an inline 5-line floor/ceil block plus `lowerAndUpperBound`) emits a single `CastUtils.${fractionalPrefix(from)}To${target.capitalize}Exact($c)` call. The `int` / `long` branch (added in #55934) is unchanged. The now-unused `lowerAndUpperBound` Scala helper is removed.
    * Eval paths for `castToByte` and `castToShort` add ANSI cases for `IntegerType` / `LongType` / `FloatType` / `DoubleType` source types (plus `ShortType` for `castToByte`) that delegate to the new helpers. These cases are matched before the generic `NumericType` ANSI case, so those source types no longer go through its multi-line `exactNumeric.toInt(b)` + bounds-check body.
    * Two small `integralPrefix(from: DataType)` / `fractionalPrefix(from: DataType)` Scala helpers handle the method-name dispatch.
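    
    The fractional branch in the list above keeps the same range condition as the removed inline block: `floor(v) <= MAX && ceil(v) >= MIN`. Because a Java cast truncates toward zero, `127.9` still fits a `byte` while `128.0` does not, and `NaN` fails both comparisons and therefore takes the overflow path. A standalone sketch of that check, where the class name `FractionNarrowingSketch` and the `overflow` helper are hypothetical and `ArithmeticException` again stands in for Spark's overflow error:
    
    ```java
    public final class FractionNarrowingSketch {
      private FractionNarrowingSketch() {}
    
      // Hypothetical stand-in for QueryExecutionErrors.castingCauseOverflowError.
      private static ArithmeticException overflow(double v, String to) {
        return new ArithmeticException("cannot cast " + v + " to " + to);
      }
    
      // In range iff the value, truncated toward zero, lies within [MIN_VALUE, MAX_VALUE].
      public static byte doubleToByteExact(double v) {
        if (Math.floor(v) <= Byte.MAX_VALUE && Math.ceil(v) >= Byte.MIN_VALUE) return (byte) v;
        throw overflow(v, "byte");
      }
    
      // The float argument is widened to double by Math.floor / Math.ceil.
      public static short floatToShortExact(float v) {
        if (Math.floor(v) <= Short.MAX_VALUE && Math.ceil(v) >= Short.MIN_VALUE) return (short) v;
        throw overflow(v, "short");
      }
    
      public static void main(String[] args) {
        System.out.println(doubleToByteExact(127.9));   // truncates toward zero
        System.out.println(doubleToByteExact(-128.9));  // truncates toward zero
        try {
          doubleToByteExact(Double.NaN);                // NaN fails both comparisons
        } catch (ArithmeticException e) {
          System.out.println("overflow: " + e.getMessage());
        }
      }
    }
    ```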
    
    ### Why are the changes needed?
    
    Part of SPARK-56908 (umbrella). The byte/short narrowing ANSI bodies were 5 lines each across 8 codegen call sites; this PR collapses them to one line per call site, matching the int/long target work merged in #55934.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No. The compiled behavior is identical; only the emitted Java source text changes.
    
    ### How was this patch tested?
    
    ```
    build/sbt "catalyst/testOnly *CastSuite *CastWithAnsiOnSuite *CastWithAnsiOffSuite *AnsiCastSuite *TryCastSuite"
    ```
    
    307/307 pass.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Generated-by: Cursor 1.x
    
    Closes #55935 from gengliangwang/SPARK-56910-cast-byte-short.
    
    Authored-by: Gengliang Wang <[email protected]>
    Signed-off-by: Gengliang Wang <[email protected]>
---
 .../spark/sql/catalyst/expressions/CastUtils.java  | 98 ++++++++++++++++++++++
 .../spark/sql/catalyst/expressions/Cast.scala      | 82 +++++++++---------
 2 files changed, 142 insertions(+), 38 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java
new file mode 100644
index 000000000000..700f7e41d233
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions;
+
+import org.apache.spark.sql.errors.QueryExecutionErrors;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+
+/**
+ * Static helpers used by {@code Cast.doGenCode} (and corresponding eval
+ * paths) for ANSI overflow-checked narrowing to {@code byte} / {@code short}.
+ *
+ * <p>Narrowing to {@code int} / {@code long} is handled by calling the existing
+ * {@code LongExactNumeric} / {@code FloatExactNumeric} / {@code DoubleExactNumeric}
+ * Scala objects directly from codegen (see SPARK-56909). The helpers below
+ * cover {@code byte} / {@code short} only, since {@code ByteExactNumeric} /
+ * {@code ShortExactNumeric} don't expose a cross-type narrowing API.
+ *
+ * <p>The source and target {@link DataType} objects referenced by the overflow
+ * error message are held in {@code private static final} fields so the happy
+ * path performs no per-row {@code references[]} lookups.
+ */
+public final class CastUtils {
+
+  private CastUtils() {}
+
+  private static final DataType SHORT = DataTypes.ShortType;
+  private static final DataType INT = DataTypes.IntegerType;
+  private static final DataType LONG = DataTypes.LongType;
+  private static final DataType BYTE = DataTypes.ByteType;
+  private static final DataType FLOAT = DataTypes.FloatType;
+  private static final DataType DOUBLE = DataTypes.DoubleType;
+
+  // ----- integral narrowing (ANSI: throw on overflow) -----
+
+  public static byte shortToByteExact(short v) {
+    if (v == (byte) v) return (byte) v;
+    throw QueryExecutionErrors.castingCauseOverflowError(v, SHORT, BYTE);
+  }
+
+  public static byte intToByteExact(int v) {
+    if (v == (byte) v) return (byte) v;
+    throw QueryExecutionErrors.castingCauseOverflowError(v, INT, BYTE);
+  }
+
+  public static byte longToByteExact(long v) {
+    if (v == (byte) v) return (byte) v;
+    throw QueryExecutionErrors.castingCauseOverflowError(v, LONG, BYTE);
+  }
+
+  public static short intToShortExact(int v) {
+    if (v == (short) v) return (short) v;
+    throw QueryExecutionErrors.castingCauseOverflowError(v, INT, SHORT);
+  }
+
+  public static short longToShortExact(long v) {
+    if (v == (short) v) return (short) v;
+    throw QueryExecutionErrors.castingCauseOverflowError(v, LONG, SHORT);
+  }
+
+  // ----- fractional -> integral (ANSI: throw on overflow) -----
+  // Mirrors castFractionToIntegralTypeCode: floor(v) <= MAX && ceil(v) >= MIN.
+
+  public static byte floatToByteExact(float v) {
+    if (Math.floor(v) <= Byte.MAX_VALUE && Math.ceil(v) >= Byte.MIN_VALUE) return (byte) v;
+    throw QueryExecutionErrors.castingCauseOverflowError(v, FLOAT, BYTE);
+  }
+
+  public static byte doubleToByteExact(double v) {
+    if (Math.floor(v) <= Byte.MAX_VALUE && Math.ceil(v) >= Byte.MIN_VALUE) return (byte) v;
+    throw QueryExecutionErrors.castingCauseOverflowError(v, DOUBLE, BYTE);
+  }
+
+  public static short floatToShortExact(float v) {
+    if (Math.floor(v) <= Short.MAX_VALUE && Math.ceil(v) >= Short.MIN_VALUE) return (short) v;
+    throw QueryExecutionErrors.castingCauseOverflowError(v, FLOAT, SHORT);
+  }
+
+  public static short doubleToShortExact(double v) {
+    if (Math.floor(v) <= Short.MAX_VALUE && Math.ceil(v) >= Short.MIN_VALUE) return (short) v;
+    throw QueryExecutionErrors.castingCauseOverflowError(v, DOUBLE, SHORT);
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 419ca3f32d88..0611c3e9bfb3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -984,6 +984,14 @@ case class Cast(
           errorOrNull(t, from, ShortType)
         }
       })
+    case IntegerType if ansiEnabled =>
+      b => CastUtils.intToShortExact(b.asInstanceOf[Int])
+    case LongType if ansiEnabled =>
+      b => CastUtils.longToShortExact(b.asInstanceOf[Long])
+    case FloatType if ansiEnabled =>
+      b => CastUtils.floatToShortExact(b.asInstanceOf[Float])
+    case DoubleType if ansiEnabled =>
+      b => CastUtils.doubleToShortExact(b.asInstanceOf[Double])
     case x: NumericType if ansiEnabled =>
       val exactNumeric = PhysicalNumericType.exactNumeric(x)
       b =>
@@ -1040,6 +1048,16 @@ case class Cast(
           errorOrNull(t, from, ByteType)
         }
       })
+    case ShortType if ansiEnabled =>
+      b => CastUtils.shortToByteExact(b.asInstanceOf[Short])
+    case IntegerType if ansiEnabled =>
+      b => CastUtils.intToByteExact(b.asInstanceOf[Int])
+    case LongType if ansiEnabled =>
+      b => CastUtils.longToByteExact(b.asInstanceOf[Long])
+    case FloatType if ansiEnabled =>
+      b => CastUtils.floatToByteExact(b.asInstanceOf[Float])
+    case DoubleType if ansiEnabled =>
+      b => CastUtils.doubleToByteExact(b.asInstanceOf[Double])
     case x: NumericType if ansiEnabled =>
       val exactNumeric = PhysicalNumericType.exactNumeric(x)
       b =>
@@ -1999,28 +2017,13 @@ case class Cast(
       }).getClass.getCanonicalName.stripSuffix("$")
       (c, evPrim, _) => code"$evPrim = $numericObj.toInt($c);"
     } else {
-      val fromDt = ctx.addReferenceObj("from", from, from.getClass.getName)
-      val toDt = ctx.addReferenceObj("to", to, to.getClass.getName)
-      (c, evPrim, _) =>
-        code"""
-          if ($c == ($integralType) $c) {
-            $evPrim = ($integralType) $c;
-          } else {
-            throw QueryExecutionErrors.castingCauseOverflowError($c, $fromDt, $toDt);
-          }
-        """
-    }
-  }
-
-
-  private[this] def lowerAndUpperBound(integralType: String): (String, String) = {
-    val (min, max, typeIndicator) = integralType.toLowerCase(Locale.ROOT) match {
-      case "long" => (Long.MinValue, Long.MaxValue, "L")
-      case "int" => (Int.MinValue, Int.MaxValue, "")
-      case "short" => (Short.MinValue, Short.MaxValue, "")
-      case "byte" => (Byte.MinValue, Byte.MaxValue, "")
+      // Byte / short narrowing: call the matching CastUtils helper. Existing *ExactNumeric
+      // objects don't expose cross-type narrowing to byte / short (their toByte / toShort are
+      // same-type identities), so a Java helper is the cleanest fit.
+      val castUtils = classOf[CastUtils].getName
+      val method = s"${integralPrefix(from)}To${integralType.capitalize}Exact"
+      (c, evPrim, _) => code"$evPrim = $castUtils.$method($c);"
     }
-    (min.toString + typeIndicator, max.toString + typeIndicator)
   }
 
   private[this] def castFractionToIntegralTypeCode(
@@ -2042,26 +2045,29 @@ case class Cast(
       val method = s"to${integralType.capitalize}"
       (c, evPrim, _) => code"$evPrim = $numericObj.$method($c);"
     } else {
-      val (min, max) = lowerAndUpperBound(integralType)
-      val mathClass = classOf[Math].getName
-      val fromDt = ctx.addReferenceObj("from", from, from.getClass.getName)
-      val toDt = ctx.addReferenceObj("to", to, to.getClass.getName)
-      // When casting floating values to integral types, Spark uses the method `Numeric.toInt`
-      // Or `Numeric.toLong` directly. For positive floating values, it is equivalent to
-      // `Math.floor`; for negative floating values, it is equivalent to `Math.ceil`.
-      // So, we can use the condition `Math.floor(x) <= upperBound && Math.ceil(x) >= lowerBound`
-      // to check if the floating value x is in the range of an integral type after rounding.
-      (c, evPrim, _) =>
-        code"""
-          if ($mathClass.floor($c) <= $max && $mathClass.ceil($c) >= $min) {
-            $evPrim = ($integralType) $c;
-          } else {
-            throw QueryExecutionErrors.castingCauseOverflowError($c, $fromDt, $toDt);
-          }
-        """
+      // Float / double -> byte / short: same rationale as the integral byte / short branch
+      // above -- no equivalent *ExactNumeric API, so route through CastUtils.
+      val castUtils = classOf[CastUtils].getName
+      val method = s"${fractionalPrefix(from)}To${integralType.capitalize}Exact"
+      (c, evPrim, _) => code"$evPrim = $castUtils.$method($c);"
     }
   }
 
+  private[this] def integralPrefix(from: DataType): String = from match {
+    case ShortType => "short"
+    case IntegerType => "int"
+    case LongType => "long"
+    case _ => throw SparkException.internalError(
+      s"Unexpected source type $from for castIntegralTypeToIntegralTypeExactCode")
+  }
+
+  private[this] def fractionalPrefix(from: DataType): String = from match {
+    case FloatType => "float"
+    case DoubleType => "double"
+    case _ => throw SparkException.internalError(
+      s"Unexpected source type $from for castFractionToIntegralTypeCode")
+  }
+
   private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
     case _: StringType if ansiEnabled =>
       val stringUtils = UTF8StringUtils.getClass.getCanonicalName.stripSuffix("$")

