This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a96e9ca81518 [SPARK-53968][SQL] Store decimal precision loss conf in arithmetic expressions
a96e9ca81518 is described below

commit a96e9ca81518bff31b0089d459fe78804ca1aa38
Author: Stefan Kandic <[email protected]>
AuthorDate: Wed Oct 22 21:07:33 2025 +0800

    [SPARK-53968][SQL] Store decimal precision loss conf in arithmetic expressions
    
    ### What changes were proposed in this pull request?
    Currently, arithmetic expressions such as `Add` and `Multiply` use the 
configuration `spark.sql.decimalOperations.allowPrecisionLoss` to determine 
their output type when working with decimal values. This approach is 
problematic because if the expression is transformed or copied, its return type 
could change depending on the active configuration value.
    
    This issue can occur during view resolution: the config may have one value during analysis and a different one during query optimization. If a referenced expression changes type and that reference is reused elsewhere in the plan, it triggers a plan validation error.
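
    A minimal sketch of the failure mode (a hypothetical repro, not part of this patch: it flips the conf directly via `SQLConf.get`, while the real repro goes through a view):

    ```scala
    import org.apache.spark.sql.catalyst.expressions.{Add, Literal}
    import org.apache.spark.sql.internal.SQLConf
    import org.apache.spark.sql.types.Decimal

    // DECIMAL(38,18) + DECIMAL(38,18) needs precision 39, which must be reduced to fit.
    val a = Literal(Decimal(BigDecimal(100), 38, 18))
    val b = Literal(Decimal(BigDecimal(0), 38, 18))

    SQLConf.get.setConfString("spark.sql.decimalOperations.allowPrecisionLoss", "true")
    val expr = Add(a, b)
    val t1 = expr.dataType    // DECIMAL(38,17): scale is sacrificed to keep the integral part

    SQLConf.get.setConfString("spark.sql.decimalOperations.allowPrecisionLoss", "false")
    // Before this patch, the copy re-read the conf when computing its lazy data type
    // and reported DECIMAL(38,18), a different type for the "same" expression.
    val t2 = expr.copy().dataType
    assert(t1 == t2)          // failed before this patch; holds after it
    ```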
    
    ### Why are the changes needed?
    To address this, we should follow a similar approach to what was done for ANSI mode: store the relevant context directly within the expression as part of its state. This ensures the expression's return type remains stable and unaffected by configuration changes when it is copied or transformed. To make this transition smooth, I've generalized the existing `EvalMode` used for ANSI into a context that can be extended with additional configuration dimensions.
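
    A sketch of the resulting shape (reusing `a` and `b` from the sketch above; the default constructor still snapshots the active conf, but the captured context now survives copies and transforms):

    ```scala
    import org.apache.spark.sql.catalyst.expressions.{Add, EvalMode, NumericEvalContext}

    // Capture the deciding state once, at construction time...
    val ctx = NumericEvalContext(EvalMode.ANSI, allowDecimalPrecisionLoss = true)
    val add = Add(a, b, ctx)

    // ...and it travels with the node: copies keep `ctx`, so `dataType` no longer
    // depends on whichever conf is active when the plan is later re-analyzed,
    // e.g. when a view embedding this expression is read back.
    assert(add.copy().dataType == add.dataType)
    ```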
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Added a new unit test in SQLViewSuite that failed with a plan validation error before this change.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #52681 from stefankandic/fixViewDec.
    
    Authored-by: Stefan Kandic <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../BinaryArithmeticWithDatetimeResolver.scala     | 36 ++++-----
 .../sql/catalyst/expressions/Expression.scala      |  8 +-
 .../catalyst/expressions/NumericEvalContext.scala  | 57 ++++++++++++++
 .../sql/catalyst/expressions/aggregate/Sum.scala   | 23 +++---
 .../sql/catalyst/expressions/arithmetic.scala      | 88 ++++++++++++++++------
 .../catalyst/expressions/bitwiseExpressions.scala  | 12 ++-
 .../expressions/ArithmeticExpressionSuite.scala    | 59 ++++++++-------
 .../spark/sql/SparkSessionExtensionSuite.scala     |  3 +-
 .../connector/functions/V2FunctionBenchmark.scala  |  4 +-
 .../apache/spark/sql/execution/SQLViewSuite.scala  | 59 +++++++++++++++
 10 files changed, 262 insertions(+), 87 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/BinaryArithmeticWithDatetimeResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/BinaryArithmeticWithDatetimeResolver.scala
index 08407bbe96cc..cdfd942ca09a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/BinaryArithmeticWithDatetimeResolver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/BinaryArithmeticWithDatetimeResolver.scala
@@ -63,7 +63,7 @@ import org.apache.spark.sql.types.DayTimeIntervalType.DAY
 
 object BinaryArithmeticWithDatetimeResolver {
   def resolve(expr: Expression): Expression = expr match {
-    case a @ Add(l, r, mode) =>
+    case a @ Add(l, r, context) =>
       (l.dataType, r.dataType) match {
         case (DateType, DayTimeIntervalType(DAY, DAY)) => DateAdd(l, ExtractANSIIntervalDays(r))
         case (DateType, _: DayTimeIntervalType) => TimestampAddInterval(Cast(l, TimestampType), r)
@@ -83,7 +83,7 @@ object BinaryArithmeticWithDatetimeResolver {
         case (_: AnsiIntervalType, _: NullType) =>
           a.copy(right = Cast(a.right, a.left.dataType))
         case (DateType, CalendarIntervalType) =>
-          DateAddInterval(l, r, ansiEnabled = mode == EvalMode.ANSI)
+          DateAddInterval(l, r, ansiEnabled = context.evalMode == EvalMode.ANSI)
         case (_: TimeType, _: DayTimeIntervalType) => TimeAddInterval(l, r)
         case (_: DatetimeType, _: NullType) =>
           a.copy(right = Cast(a.right, DayTimeIntervalType.DEFAULT))
@@ -93,24 +93,26 @@ object BinaryArithmeticWithDatetimeResolver {
         case (_, CalendarIntervalType | _: DayTimeIntervalType) =>
           Cast(TimestampAddInterval(l, r), l.dataType)
         case (CalendarIntervalType, DateType) =>
-          DateAddInterval(r, l, ansiEnabled = mode == EvalMode.ANSI)
+          DateAddInterval(r, l, ansiEnabled = context.evalMode == EvalMode.ANSI)
         case (CalendarIntervalType | _: DayTimeIntervalType, _) =>
           Cast(TimestampAddInterval(r, l), r.dataType)
         case (DateType, dt) if dt != StringType => DateAdd(l, r)
         case (dt, DateType) if dt != StringType => DateAdd(r, l)
         case _ => a
       }
-    case s @ Subtract(l, r, mode) =>
+    case s @ Subtract(l, r, context) =>
       (l.dataType, r.dataType) match {
         case (DateType, DayTimeIntervalType(DAY, DAY)) =>
-          DateAdd(l, UnaryMinus(ExtractANSIIntervalDays(r), mode == EvalMode.ANSI))
+          DateAdd(l, UnaryMinus(ExtractANSIIntervalDays(r), context.evalMode == EvalMode.ANSI))
         case (DateType, _: DayTimeIntervalType) =>
           DatetimeSub(l, r,
-            TimestampAddInterval(Cast(l, TimestampType), UnaryMinus(r, mode == EvalMode.ANSI)))
+            TimestampAddInterval(Cast(l, TimestampType),
+              UnaryMinus(r, context.evalMode == EvalMode.ANSI)))
         case (DateType, _: YearMonthIntervalType) =>
-          DatetimeSub(l, r, DateAddYMInterval(l, UnaryMinus(r, mode == EvalMode.ANSI)))
+          DatetimeSub(l, r, DateAddYMInterval(l, UnaryMinus(r, context.evalMode == EvalMode.ANSI)))
         case (TimestampType | TimestampNTZType, _: YearMonthIntervalType) =>
-          DatetimeSub(l, r, TimestampAddYMInterval(l, UnaryMinus(r, mode == EvalMode.ANSI)))
+          DatetimeSub(l, r, TimestampAddYMInterval(l,
+            UnaryMinus(r, context.evalMode == EvalMode.ANSI)))
         case (CalendarIntervalType, CalendarIntervalType) |
              (_: DayTimeIntervalType, _: DayTimeIntervalType) =>
           s
@@ -124,15 +126,15 @@ object BinaryArithmeticWithDatetimeResolver {
             r,
             DateAddInterval(
               l,
-              UnaryMinus(r, mode == EvalMode.ANSI),
-              ansiEnabled = mode == EvalMode.ANSI
+              UnaryMinus(r, context.evalMode == EvalMode.ANSI),
+              ansiEnabled = context.evalMode == EvalMode.ANSI
             )
           )
         case (_: TimeType, _: DayTimeIntervalType) =>
-          DatetimeSub(l, r, TimeAddInterval(l, UnaryMinus(r, mode == EvalMode.ANSI)))
+          DatetimeSub(l, r, TimeAddInterval(l, UnaryMinus(r, context.evalMode == EvalMode.ANSI)))
         case (_, CalendarIntervalType | _: DayTimeIntervalType) =>
           Cast(DatetimeSub(l, r,
-            TimestampAddInterval(l, UnaryMinus(r, mode == EvalMode.ANSI))), l.dataType)
+            TimestampAddInterval(l, UnaryMinus(r, context.evalMode == EvalMode.ANSI))), l.dataType)
         case _
           if AnyTimestampTypeExpression.unapply(l) ||
             AnyTimestampTypeExpression.unapply(r) =>
@@ -142,19 +144,19 @@ object BinaryArithmeticWithDatetimeResolver {
         case (_: TimeType, _: TimeType) => SubtractTimes(l, r)
         case _ => s
       }
-    case m @ Multiply(l, r, mode) =>
+    case m @ Multiply(l, r, context) =>
       (l.dataType, r.dataType) match {
-        case (CalendarIntervalType, _) => MultiplyInterval(l, r, mode == EvalMode.ANSI)
-        case (_, CalendarIntervalType) => MultiplyInterval(r, l, mode == EvalMode.ANSI)
+        case (CalendarIntervalType, _) => MultiplyInterval(l, r, context.evalMode == EvalMode.ANSI)
+        case (_, CalendarIntervalType) => MultiplyInterval(r, l, context.evalMode == EvalMode.ANSI)
         case (_: YearMonthIntervalType, _) => MultiplyYMInterval(l, r)
         case (_, _: YearMonthIntervalType) => MultiplyYMInterval(r, l)
         case (_: DayTimeIntervalType, _) => MultiplyDTInterval(l, r)
         case (_, _: DayTimeIntervalType) => MultiplyDTInterval(r, l)
         case _ => m
       }
-    case d @ Divide(l, r, mode) =>
+    case d @ Divide(l, r, context) =>
       (l.dataType, r.dataType) match {
-        case (CalendarIntervalType, _) => DivideInterval(l, r, mode == EvalMode.ANSI)
+        case (CalendarIntervalType, _) => DivideInterval(l, r, context.evalMode == EvalMode.ANSI)
         case (_: YearMonthIntervalType, _) => DivideYMInterval(l, r)
         case (_: DayTimeIntervalType, _) => DivideDTInterval(l, r)
         case _ => d
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index f706741fc98c..b61f7ee0ee16 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -1423,13 +1423,13 @@ trait CommutativeExpression extends Expression {
   protected def buildCanonicalizedPlan(
       collectOperands: PartialFunction[Expression, Seq[Expression]],
       buildBinaryOp: (Expression, Expression) => Expression,
-      evalMode: Option[EvalMode.Value] = None): Expression = {
+      evalContext: Option[NumericEvalContext] = None): Expression = {
     val operands = orderCommutative(collectOperands)
     val reorderResult =
       if (operands.length < SQLConf.get.getConf(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD)) {
         operands.reduce(buildBinaryOp)
       } else {
-        MultiCommutativeOp(operands, this.getClass, evalMode)(this)
+        MultiCommutativeOp(operands, this.getClass, evalContext)(this)
       }
     reorderResult
   }
@@ -1446,7 +1446,7 @@ trait CommutativeExpression extends Expression {
  *      Add, Multiply, And, Or, BitwiseAnd, BitwiseOr, BitwiseXor.
  * @param operands A sequence of operands that produces a commutative expression tree.
  * @param opCls The class of the root operator of the expression tree.
- * @param evalMode The optional expression evaluation mode.
+ * @param evalContext The optional expression evaluation context.
  * @param originalRoot Root operator of the commutative expression tree before canonicalization.
  *                     This object reference is used to deduce the return dataType of Add and
  *                     Multiply operations when the input datatype is decimal.
@@ -1454,7 +1454,7 @@ trait CommutativeExpression extends Expression {
 case class MultiCommutativeOp(
     operands: Seq[Expression],
     opCls: Class[_],
-    evalMode: Option[EvalMode.Value])(originalRoot: Expression) extends Unevaluable {
+    evalContext: Option[NumericEvalContext])(originalRoot: Expression) extends Unevaluable {
   // Helper method to deduce the data type of a single operation.
   private def singleOpDataType(lType: DataType, rType: DataType): DataType = {
     originalRoot match {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/NumericEvalContext.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/NumericEvalContext.scala
new file mode 100644
index 000000000000..28ec58d42475
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/NumericEvalContext.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Encapsulates the evaluation context for expressions, capturing SQL configuration
+ * state at expression construction time.
+ *
+ * This context must be stored as part of the expression's state to ensure deterministic
+ * evaluation. Without it, copying an expression or evaluating it in a different context
+ * (e.g., inside a view) could produce different results due to changed SQL configuration
+ * values.
+ *
+ * @param evalMode                  The error handling mode (LEGACY, ANSI, or TRY) that determines
+ *                                  overflow behavior and exception handling for operations like
+ *                                  arithmetic and casts.
+ * @param allowDecimalPrecisionLoss Whether decimal operations are allowed to lose precision
+ *                                  when the result type cannot represent the full precision.
+ *                                  Corresponds to
+ *                                  spark.sql.decimalOperations.allowPrecisionLoss.
+ */
+case class NumericEvalContext private(
+    evalMode: EvalMode.Value,
+    allowDecimalPrecisionLoss: Boolean
+)
+
+case object NumericEvalContext {
+
+  def apply(
+      evalMode: EvalMode.Value,
+      allowDecimalPrecisionLoss: Boolean = SQLConf.get.decimalOperationsAllowPrecisionLoss
+  ): NumericEvalContext = {
+    new NumericEvalContext(evalMode, allowDecimalPrecisionLoss)
+  }
+
+  def fromSQLConf(conf: SQLConf): NumericEvalContext = {
+    NumericEvalContext(
+      EvalMode.fromSQLConf(conf),
+      conf.decimalOperationsAllowPrecisionLoss)
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index dfd41ad12a28..d066a87fc791 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -43,19 +43,19 @@ import org.apache.spark.sql.types._
   since = "1.0.0")
 case class Sum(
     child: Expression,
-    evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get))
+    evalContext: NumericEvalContext = NumericEvalContext.fromSQLConf(SQLConf.get))
   extends DeclarativeAggregate
   with ImplicitCastInputTypes
   with UnaryLike[Expression]
   with SupportQueryContext {
 
-  def this(child: Expression) = this(child, EvalMode.fromSQLConf(SQLConf.get))
+  def this(child: Expression) = this(child, NumericEvalContext.fromSQLConf(SQLConf.get))
 
   private def shouldTrackIsEmpty: Boolean = resultType match {
     case _: DecimalType => true
     // For try_sum(), the result of following data types can be null on overflow.
     // Thus we need additional buffer to keep track of whether overflow happens.
-    case _: IntegralType | _: AnsiIntervalType if evalMode == EvalMode.TRY => true
+    case _: IntegralType | _: AnsiIntervalType if evalContext.evalMode == EvalMode.TRY => true
     case _ => false
   }
 
@@ -89,7 +89,7 @@ case class Sum(
 
   private def add(left: Expression, right: Expression): Expression = left.dataType match {
     case _: DecimalType => DecimalAddNoOverflowCheck(left, right, left.dataType)
-    case _ => Add(left, right, evalMode)
+    case _ => Add(left, right, evalContext)
   }
 
   override lazy val aggBufferAttributes = if (shouldTrackIsEmpty) {
@@ -176,7 +176,7 @@ case class Sum(
     resultType match {
       case d: DecimalType =>
         val checkOverflowInSum =
-          CheckOverflowInSum(sum, d, evalMode != EvalMode.ANSI, getContextOrNull())
+          CheckOverflowInSum(sum, d, evalContext.evalMode != EvalMode.ANSI, getContextOrNull())
         If(isEmpty, Literal.create(null, resultType), checkOverflowInSum)
       case _ if shouldTrackIsEmpty =>
         If(isEmpty, Literal.create(null, resultType), sum)
@@ -187,11 +187,12 @@ case class Sum(
   // The flag `evalMode` won't be shown in the `toString` or `toAggString` methods
   override def flatArguments: Iterator[Any] = Iterator(child)
 
-  override def initQueryContext(): Option[QueryContext] = if (evalMode == EvalMode.ANSI) {
-    Some(origin.context)
-  } else {
-    None
-  }
+  override def initQueryContext(): Option[QueryContext] =
+    if (evalContext.evalMode == EvalMode.ANSI) {
+      Some(origin.context)
+    } else {
+      None
+    }
 
   override protected def withNewChildInternal(newChild: Expression): Expression =
     copy(child = newChild)
@@ -218,7 +219,7 @@ object TrySumExpressionBuilder extends ExpressionBuilder {
   override def build(funcName: String, expressions: Seq[Expression]): Expression = {
     val numArgs = expressions.length
     if (numArgs == 1) {
-      Sum(expressions.head, EvalMode.TRY)
+      Sum(expressions.head, NumericEvalContext(EvalMode.TRY))
     } else {
       throw QueryCompilationErrors.wrongNumArgsError(funcName, Seq(1, 2), numArgs)
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 032e04ce2cdd..1c93a6586761 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -196,7 +196,9 @@ abstract class BinaryArithmetic extends BinaryOperator with SupportQueryContext
   override def contextIndependentFoldable: Boolean =
     left.contextIndependentFoldable && right.contextIndependentFoldable
 
-  protected val evalMode: EvalMode.Value
+  val evalContext: NumericEvalContext
+
+  def evalMode: EvalMode.Value = evalContext.evalMode
 
   private lazy val internalDataType: DataType = (left.dataType, right.dataType) match {
     case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) =>
@@ -224,7 +226,7 @@ abstract class BinaryArithmetic extends BinaryOperator with SupportQueryContext
   // When `spark.sql.decimalOperations.allowPrecisionLoss` is set to true, if the precision / scale
   // needed are out of the range of available values, the scale is reduced up to 6, in order to
   // prevent the truncation of the integer part of the decimals.
-  protected def allowPrecisionLoss: Boolean = SQLConf.get.decimalOperationsAllowPrecisionLoss
+  protected def allowPrecisionLoss: Boolean = evalContext.allowDecimalPrecisionLoss
 
   protected def resultDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = {
     throw SparkException.internalError(
@@ -405,11 +407,12 @@ object BinaryArithmetic {
 case class Add(
     left: Expression,
     right: Expression,
-    evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends BinaryArithmetic
+    evalContext: NumericEvalContext = NumericEvalContext.fromSQLConf(SQLConf.get))
+  extends BinaryArithmetic
   with CommutativeExpression {
 
   def this(left: Expression, right: Expression) =
-    this(left, right, EvalMode.fromSQLConf(SQLConf.get))
+    this(left, right, NumericEvalContext.fromSQLConf(SQLConf.get))
 
   override def inputType: AbstractDataType = TypeCollection.NumericAndInterval
 
@@ -465,9 +468,9 @@ case class Add(
 
   override lazy val canonicalized: Expression = {
     val reorderResult = buildCanonicalizedPlan(
-      { case Add(l, r, em) if em == evalMode => Seq(l, r) },
-      { case (l: Expression, r: Expression) => Add(l, r, evalMode)},
-      Some(evalMode)
+      { case Add(l, r, em) if em == evalContext => Seq(l, r) },
+      { case (l: Expression, r: Expression) => Add(l, r, evalContext)},
+      Some(evalContext)
     )
     if (resolved && reorderResult.resolved && reorderResult.dataType == dataType) {
       reorderResult
@@ -479,6 +482,11 @@ case class Add(
   }
 }
 
+object Add {
+  def apply(left: Expression, right: Expression, evalMode: EvalMode.Value): Add =
+    new Add(left, right, NumericEvalContext(evalMode))
+}
+
 @ExpressionDescription(
   usage = "expr1 _FUNC_ expr2 - Returns `expr1`-`expr2`.",
   examples = """
@@ -491,10 +499,11 @@ case class Add(
 case class Subtract(
     left: Expression,
     right: Expression,
-    evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends BinaryArithmetic {
+    evalContext: NumericEvalContext = NumericEvalContext.fromSQLConf(SQLConf.get))
+  extends BinaryArithmetic {
 
   def this(left: Expression, right: Expression) =
-    this(left, right, EvalMode.fromSQLConf(SQLConf.get))
+    this(left, right, NumericEvalContext.fromSQLConf(SQLConf.get))
 
   override def inputType: AbstractDataType = TypeCollection.NumericAndInterval
 
@@ -555,6 +564,11 @@ case class Subtract(
     newLeft: Expression, newRight: Expression): Subtract = copy(left = newLeft, right = newRight)
 }
 
+object Subtract {
+  def apply(left: Expression, right: Expression, evalMode: EvalMode.Value): Subtract =
+    new Subtract(left, right, NumericEvalContext(evalMode))
+}
+
 @ExpressionDescription(
   usage = "expr1 _FUNC_ expr2 - Returns `expr1`*`expr2`.",
   examples = """
@@ -567,11 +581,12 @@ case class Subtract(
 case class Multiply(
     left: Expression,
     right: Expression,
-    evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends BinaryArithmetic
+    evalContext: NumericEvalContext = NumericEvalContext.fromSQLConf(SQLConf.get))
+  extends BinaryArithmetic
   with CommutativeExpression {
 
   def this(left: Expression, right: Expression) =
-    this(left, right, EvalMode.fromSQLConf(SQLConf.get))
+    this(left, right, NumericEvalContext.fromSQLConf(SQLConf.get))
 
   override def inputType: AbstractDataType = NumericType
 
@@ -620,13 +635,18 @@ case class Multiply(
 
   override lazy val canonicalized: Expression = {
     buildCanonicalizedPlan(
-      { case Multiply(l, r, em) if em == evalMode => Seq(l, r) },
-      { case (l: Expression, r: Expression) => Multiply(l, r, evalMode) },
-      Some(evalMode)
+      { case Multiply(l, r, ec) if ec == evalContext => Seq(l, r) },
+      { case (l: Expression, r: Expression) => Multiply(l, r, evalContext) },
+      Some(evalContext)
     )
   }
 }
 
+object Multiply {
+  def apply(left: Expression, right: Expression, evalMode: EvalMode.Value): Multiply =
+    new Multiply(left, right, NumericEvalContext(evalMode))
+}
+
 // Common base trait for Divide and Remainder, since these two classes are almost identical
 trait DivModLike extends BinaryArithmetic {
 
@@ -779,10 +799,11 @@ trait DivModLike extends BinaryArithmetic {
 case class Divide(
     left: Expression,
     right: Expression,
-    evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends DivModLike {
+    evalContext: NumericEvalContext = NumericEvalContext.fromSQLConf(SQLConf.get))
+  extends DivModLike {
 
   def this(left: Expression, right: Expression) =
-    this(left, right, EvalMode.fromSQLConf(SQLConf.get))
+    this(left, right, NumericEvalContext.fromSQLConf(SQLConf.get))
 
   // `try_divide` has exactly the same behavior as the legacy divide, so here it only executes
   // the error code path when `evalMode` is `ANSI`.
@@ -834,6 +855,11 @@ case class Divide(
     newLeft: Expression, newRight: Expression): Divide = copy(left = newLeft, right = newRight)
 }
 
+object Divide {
+  def apply(left: Expression, right: Expression, evalMode: EvalMode.Value): Divide =
+    new Divide(left, right, NumericEvalContext(evalMode))
+}
+
 // scalastyle:off line.size.limit
 @ExpressionDescription(
   usage = "expr1 _FUNC_ expr2 - Divide `expr1` by `expr2`. It returns NULL if an operand is NULL or `expr2` is 0. The result is casted to long.",
@@ -850,10 +876,11 @@ case class Divide(
 case class IntegralDivide(
     left: Expression,
     right: Expression,
-    evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends DivModLike {
+    evalContext: NumericEvalContext = NumericEvalContext.fromSQLConf(SQLConf.get))
+  extends DivModLike {
 
   def this(left: Expression, right: Expression) = this(left, right,
-    EvalMode.fromSQLConf(SQLConf.get))
+    NumericEvalContext.fromSQLConf(SQLConf.get))
 
   override def checkDivideOverflow: Boolean = left.dataType match {
     case LongType if failOnError => true
@@ -912,6 +939,11 @@ case class IntegralDivide(
     copy(left = newLeft, right = newRight)
 }
 
+object IntegralDivide {
+  def apply(left: Expression, right: Expression, evalMode: EvalMode.Value): IntegralDivide =
+    new IntegralDivide(left, right, NumericEvalContext(evalMode))
+}
+
 @ExpressionDescription(
   usage = "expr1 % expr2, or mod(expr1, expr2) - Returns the remainder after `expr1`/`expr2`.",
   examples = """
@@ -926,10 +958,11 @@ case class IntegralDivide(
 case class Remainder(
     left: Expression,
     right: Expression,
-    evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends DivModLike {
+    evalContext: NumericEvalContext = NumericEvalContext.fromSQLConf(SQLConf.get))
+  extends DivModLike {
 
   def this(left: Expression, right: Expression) =
-    this(left, right, EvalMode.fromSQLConf(SQLConf.get))
+    this(left, right, NumericEvalContext.fromSQLConf(SQLConf.get))
 
   override def inputType: AbstractDataType = NumericType
 
@@ -994,6 +1027,11 @@ case class Remainder(
     newLeft: Expression, newRight: Expression): Remainder = copy(left = newLeft, right = newRight)
 }
 
+object Remainder {
+  def apply(left: Expression, right: Expression, evalMode: EvalMode.Value): Remainder =
+    new Remainder(left, right, NumericEvalContext(evalMode))
+}
+
 @ExpressionDescription(
   usage = "_FUNC_(expr1, expr2) - Returns the positive value of `expr1` mod `expr2`.",
   examples = """
@@ -1008,10 +1046,11 @@ case class Remainder(
 case class Pmod(
     left: Expression,
     right: Expression,
-    evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends BinaryArithmetic {
+    evalContext: NumericEvalContext = NumericEvalContext.fromSQLConf(SQLConf.get))
+  extends BinaryArithmetic {
 
   def this(left: Expression, right: Expression) =
-    this(left, right, EvalMode.fromSQLConf(SQLConf.get))
+    this(left, right, NumericEvalContext.fromSQLConf(SQLConf.get))
 
   override def toString: String = s"pmod($left, $right)"
 
@@ -1199,6 +1238,11 @@ case class Pmod(
     copy(left = newLeft, right = newRight)
 }
 
+object Pmod {
+  def apply(left: Expression, right: Expression, evalMode: EvalMode.Value): Pmod =
+    new Pmod(left, right, NumericEvalContext(evalMode))
+}
+
 /**
  * A function that returns the least value of all parameters, skipping null values.
  * It takes at least 2 parameters, and returns null iff all parameters are null.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
index 26743ca6ff15..5fac0a93bf9b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
@@ -39,7 +39,9 @@ import org.apache.spark.sql.types._
 case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic
   with CommutativeExpression {
 
-  protected override val evalMode: EvalMode.Value = EvalMode.LEGACY
+  override def evalMode: EvalMode.Value = EvalMode.LEGACY
+
+  override val evalContext: NumericEvalContext = NumericEvalContext(evalMode)
 
   override def inputType: AbstractDataType = IntegralType
 
@@ -86,7 +88,9 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme
 case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic
   with CommutativeExpression {
 
-  protected override val evalMode: EvalMode.Value = EvalMode.LEGACY
+  override def evalMode: EvalMode.Value = EvalMode.LEGACY
+
+  override val evalContext: NumericEvalContext = NumericEvalContext(evalMode)
 
   override def inputType: AbstractDataType = IntegralType
 
@@ -133,7 +137,9 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet
 case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic
   with CommutativeExpression {
 
-  protected override val evalMode: EvalMode.Value = EvalMode.LEGACY
+  override def evalMode: EvalMode.Value = EvalMode.LEGACY
+
+  override val evalContext: NumericEvalContext = NumericEvalContext(evalMode)
 
   override def inputType: AbstractDataType = IntegralType
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index 649ce2478825..72be3031ace6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -285,34 +285,35 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
       val n1 = makeNum(p1, s1)
       val n2 = makeNum(p2, s2)
 
-      val mulActual = Multiply(
-        Literal(Decimal(BigDecimal(n1), p1, s1)),
-        Literal(Decimal(BigDecimal(n2), p2, s2))
-      )
-      val mulExact = new java.math.BigDecimal(n1).multiply(new java.math.BigDecimal(n2))
-
-      val divActual = Divide(
-        Literal(Decimal(BigDecimal(n1), p1, s1)),
-        Literal(Decimal(BigDecimal(n2), p2, s2))
-      )
-      val divExact = new java.math.BigDecimal(n1)
-        .divide(new java.math.BigDecimal(n2), 100, RoundingMode.DOWN)
-
-      val remActual = Remainder(
-        Literal(Decimal(BigDecimal(n1), p1, s1)),
-        Literal(Decimal(BigDecimal(n2), p2, s2))
-      )
-      val remExact = new java.math.BigDecimal(n1).remainder(new java.math.BigDecimal(n2))
-
-      val quotActual = IntegralDivide(
-        Literal(Decimal(BigDecimal(n1), p1, s1)),
-        Literal(Decimal(BigDecimal(n2), p2, s2))
-      )
-      val quotExact =
-        new java.math.BigDecimal(n1).divideToIntegralValue(new java.math.BigDecimal(n2))
-
       Seq(true, false).foreach { allowPrecLoss =>
         withSQLConf(SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key -> allowPrecLoss.toString) {
+          val mulActual = Multiply(
+            Literal(Decimal(BigDecimal(n1), p1, s1)),
+            Literal(Decimal(BigDecimal(n2), p2, s2))
+          )
+          val mulExact = new java.math.BigDecimal(n1).multiply(new java.math.BigDecimal(n2))
+
+          val divActual = Divide(
+            Literal(Decimal(BigDecimal(n1), p1, s1)),
+            Literal(Decimal(BigDecimal(n2), p2, s2))
+          )
+          val divExact = new java.math.BigDecimal(n1)
+            .divide(new java.math.BigDecimal(n2), 100, RoundingMode.DOWN)
+
+          val remActual = Remainder(
+            Literal(Decimal(BigDecimal(n1), p1, s1)),
+            Literal(Decimal(BigDecimal(n2), p2, s2))
+          )
+          val remExact = new java.math.BigDecimal(n1).remainder(new java.math.BigDecimal(n2))
+
+          val quotActual = IntegralDivide(
+            Literal(Decimal(BigDecimal(n1), p1, s1)),
+            Literal(Decimal(BigDecimal(n2), p2, s2))
+          )
+          val quotExact =
+            new java.math.BigDecimal(n1).divideToIntegralValue(new java.math.BigDecimal(n2))
+
+
           val mulType = Multiply(null, null).resultDecimalType(p1, s1, p2, s2)
           val mulResult = Decimal(mulExact.setScale(mulType.scale, RoundingMode.HALF_UP))
           val mulExpected =
@@ -483,7 +484,11 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
   }
 
   test("Remainder/Pmod: exception should contain SQL text context") {
-    Seq(("%", Remainder), ("pmod", Pmod)).foreach { case (symbol, exprBuilder) =>
+    type BinaryOpFn = (Expression, Expression, EvalMode.Value) => BinaryArithmetic
+    Seq[(String, BinaryOpFn)](
+      ("%", Remainder.apply),
+      ("pmod", Pmod.apply)
+    ).foreach { case (symbol, exprBuilder) =>
       val query = s"1L $symbol 0L"
       val o = Origin(
         line = Some(1),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index c1c041509c35..6ee0029b6839 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -933,7 +933,8 @@ class BrokenColumnarAdd(
     left: ColumnarExpression,
     right: ColumnarExpression,
     failOnError: Boolean = false)
-  extends Add(left, right, EvalMode.fromBoolean(failOnError)) with ColumnarExpression {
+  extends Add(left, right, NumericEvalContext(EvalMode.fromBoolean(failOnError)))
+    with ColumnarExpression {
 
   override def supportsColumnar: Boolean = left.supportsColumnar && right.supportsColumnar
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala
index 49997b5b0c18..592869968917 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala
@@ -23,7 +23,7 @@ import test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd.{JavaLo
 import org.apache.spark.benchmark.Benchmark
 import org.apache.spark.sql.Column
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{BinaryArithmetic, EvalMode, Expression}
+import org.apache.spark.sql.catalyst.expressions.{BinaryArithmetic, EvalMode, Expression, NumericEvalContext}
 import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.classic.ClassicConversions._
@@ -109,7 +109,7 @@ object V2FunctionBenchmark extends SqlBasedBenchmark {
       left: Expression,
       right: Expression,
       override val nullable: Boolean) extends BinaryArithmetic {
-    protected override val evalMode: EvalMode.Value = EvalMode.LEGACY
+    override val evalContext: NumericEvalContext = NumericEvalContext(EvalMode.LEGACY)
     override def inputType: AbstractDataType = NumericType
     override def symbol: String = "+"
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala
index f715353fd431..050a004a9353 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala
@@ -1333,4 +1333,63 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils {
       }
     }
   }
+
+  test("SPARK-53968 reading the view after allowPrecisionLoss is changed") {
+    import org.apache.spark.sql.internal.SQLConf
+    val partsTableName = "parts_tbl"
+    val ordersTableName = "orders_tbl"
+    val viewName = "view_spark_53968"
+    withTable(partsTableName, ordersTableName) {
+      spark.sql(s"""CREATE TABLE $partsTableName (
+           | part_number STRING
+           |) USING PARQUET
+           |""".stripMargin)
+      spark.sql(s"INSERT INTO $partsTableName VALUES ('part1'), ('part2')")
+
+      spark.sql(s"""CREATE TABLE $ordersTableName
+           |USING PARQUET AS
+           |SELECT * FROM VALUES
+           |('part1', CAST(100 AS DECIMAL(38,18)), CAST(NULL   AS DECIMAL(38,18))),
+           |('part2', CAST(100 AS DECIMAL(38,18)), CAST(0 AS DECIMAL(38,18))),
+           |('part3', CAST(200.23 AS DECIMAL(38,18)), CAST(100 AS DECIMAL(38,18)))
+           |AS t(part_number, unit_price, shipping_price);
+           |""".stripMargin)
+
+      Seq((true, false), (false, true)).foreach { case (oldValue, newValue) =>
+        withView(viewName) {
+          withSQLConf(SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key -> oldValue.toString) {
+            spark.sql(s"""
+                 |CREATE OR REPLACE VIEW $viewName AS
+                 |WITH order_details AS (
+                 |  SELECT
+                 |    orders.part_number,
+                 |    orders.unit_price
+                 |      + COALESCE(orders.shipping_price, CAST(0 AS DECIMAL(38, 18)))
+                 |      AS total_price
+                 |  FROM $ordersTableName orders
+                 |)
+                 |SELECT
+                 |  od.total_price
+                 |FROM order_details od LEFT JOIN $partsTableName pt
+                 |  ON pt.part_number = od.part_number
+                 |ORDER BY od.total_price
+            """.stripMargin)
+
+            val expectedResults = Seq(
+              Row(BigDecimal("100.00000000000000000")),
+              Row(BigDecimal("100.00000000000000000")),
+              Row(BigDecimal("300.23000000000000000")))
+
+            checkAnswer(spark.sql(s"SELECT * FROM $viewName"), expectedResults)
+
+            // Re-run the query with new value of the config, we should get the same result.
+            withSQLConf(SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key -> newValue.toString) {
+
+              checkAnswer(spark.sql(s"SELECT * FROM $viewName"), expectedResults)
+            }
+          }
+        }
+      }
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

