This is an automated email from the ASF dual-hosted git repository.

gengliangwang pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.x by this push:
     new ca7fa663a2b2 [SPARK-56914][SQL] Simplify decimal arithmetic codegen 
under ANSI mode
ca7fa663a2b2 is described below

commit ca7fa663a2b25f6fd950f65ac0201f603c6ce0ce
Author: Gengliang Wang <[email protected]>
AuthorDate: Fri May 29 09:06:29 2026 -0700

    [SPARK-56914][SQL] Simplify decimal arithmetic codegen under ANSI mode
    
    ### What changes were proposed in this pull request?
    
    Use `CastUtils.changePrecisionExact` / `changePrecisionOrNull` (added in
    SPARK-56911) in the `DecimalType.Fixed` codegen branches of:
    * `BinaryArithmetic.doGenCode` (covers `Add` / `Subtract` / `Multiply` on 
`Decimal`).
    * `DivModLike.doGenCode` (covers `Divide` / `IntegralDivide` /
    `Remainder` / `Pmod` on `Decimal`).
    
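    Each codegen call site previously emitted `eval1.$op(eval2).toPrecision(p, s,
    ROUND_HALF_UP, !failOnError, ctx)` followed by a null check. In ANSI mode
    (`failOnError`) this becomes a single `CastUtils.changePrecisionExact` call
    with no null check. In non-ANSI mode, `toPrecision` is replaced by
    `CastUtils.changePrecisionOrNull`, and the null check is kept.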
    
    The eval path (`BinaryArithmetic.checkDecimalOverflow`) is left as the
    original one-line `value.toPrecision(p, s, ROUND_HALF_UP, !failOnError,
    getContextOrNull())`. As noted in the review of #55938, routing a one-line
    eval call through a new helper would only be a different route to the same
    logic, with no real benefit.
    
    ### Why are the changes needed?
    
    Part of SPARK-56908 (umbrella). Decimal arithmetic is widespread in TPC-DS 
plans, and the `BinaryArithmetic` Decimal branch was one of the longer ANSI 
codegen bodies still emitted inline.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    ```
    build/sbt "catalyst/testOnly *ArithmeticExpressionSuite *DecimalSuite"
    ```
    
    60/60 pass.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Generated-by: Cursor 1.x
    
    Closes #55939 from gengliangwang/SPARK-56914-decimal-arithmetic.
    
    Authored-by: Gengliang Wang <[email protected]>
    Signed-off-by: Gengliang Wang <[email protected]>
    (cherry picked from commit 251df783c65ed22e3388b2c04e337ce1bdc069ea)
    Signed-off-by: Gengliang Wang <[email protected]>
---
 .../sql/catalyst/expressions/arithmetic.scala      | 52 +++++++++++++---------
 1 file changed, 32 insertions(+), 20 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 348b45472c57..23fffb162a8f 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -278,19 +278,21 @@ abstract class BinaryArithmetic extends BinaryOperator 
with SupportQueryContext
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = 
dataType match {
     case DecimalType.Fixed(precision, scale) =>
-      val errorContextCode = getContextOrNullCode(ctx, failOnError)
-      val updateIsNull = if (failOnError) {
-        ""
+      val castUtils = classOf[CastUtils].getName
+      if (failOnError) {
+        val errorContextCode = getContextOrNullCode(ctx)
+        defineCodeGen(ctx, ev, (eval1, eval2) =>
+          s"$castUtils.changePrecisionExact(" +
+            s"$eval1.$decimalMethod($eval2), $precision, $scale, 
$errorContextCode)")
       } else {
-        s"${ev.isNull} = ${ev.value} == null;"
+        nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
+          s"""
+             |${ev.value} = $castUtils.changePrecisionOrNull(
+             |  $eval1.$decimalMethod($eval2), $precision, $scale);
+             |${ev.isNull} = ${ev.value} == null;
+           """.stripMargin
+        })
       }
-      nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
-        s"""
-           |${ev.value} = $eval1.$decimalMethod($eval2).toPrecision(
-           |  $precision, $scale, Decimal.ROUND_HALF_UP(), ${!failOnError}, 
$errorContextCode);
-           |$updateIsNull
-       """.stripMargin
-      })
     case CalendarIntervalType =>
       val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
       defineCodeGen(ctx, ev, (eval1, eval2) => 
s"$iu.$calendarIntervalMethod($eval1, $eval2)")
@@ -706,16 +708,26 @@ trait DivModLike extends BinaryArithmetic {
     val errorContextCode = getContextOrNullCode(ctx, failOnError)
     val operation = super.dataType match {
       case DecimalType.Fixed(precision, scale) =>
+        val castUtils = classOf[CastUtils].getName
         val decimalValue = ctx.freshName("decimalValue")
-        s"""
-           |Decimal $decimalValue = 
${eval1.value}.$decimalMethod(${eval2.value}).toPrecision(
-           |  $precision, $scale, Decimal.ROUND_HALF_UP(), ${!failOnError}, 
$errorContextCode);
-           |if ($decimalValue != null) {
-           |  ${ev.value} = ${decimalToDataTypeCodeGen(s"$decimalValue")};
-           |} else {
-           |  ${ev.isNull} = true;
-           |}
-           |""".stripMargin
+        if (failOnError) {
+          s"""
+             |Decimal $decimalValue = $castUtils.changePrecisionExact(
+             |  ${eval1.value}.$decimalMethod(${eval2.value}), $precision, 
$scale,
+             |  $errorContextCode);
+             |${ev.value} = ${decimalToDataTypeCodeGen(s"$decimalValue")};
+             |""".stripMargin
+        } else {
+          s"""
+             |Decimal $decimalValue = $castUtils.changePrecisionOrNull(
+             |  ${eval1.value}.$decimalMethod(${eval2.value}), $precision, 
$scale);
+             |if ($decimalValue != null) {
+             |  ${ev.value} = ${decimalToDataTypeCodeGen(s"$decimalValue")};
+             |} else {
+             |  ${ev.isNull} = true;
+             |}
+             |""".stripMargin
+        }
       case _ => s"${ev.value} = ($javaType)(${eval1.value} $symbol 
${eval2.value});"
     }
     val checkIntegralDivideOverflow = if (checkDivideOverflow) {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to