This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 803ea9590145 [SPARK-54041][SQL] Refactor ParameterizedQuery arguments 
validation
803ea9590145 is described below

commit 803ea95901455d498b460ad7e52ed9ac94e118d9
Author: mihailoale-db <[email protected]>
AuthorDate: Wed Oct 29 20:27:41 2025 +0800

    [SPARK-54041][SQL] Refactor ParameterizedQuery arguments validation
    
    ### What changes were proposed in this pull request?
    In this issue I propose to refactor `ParameterizedQuery` argument validation into a new
    `ParameterizedQueryArgumentsValidator` object so that it can be reused by both the
    single-pass and fixed-point analyzer implementations. We also drop the redundant recursion
    in the `Alias` case of `isNotAllowed`: the old `case a: Alias => isNotAllowed(a.child)`
    re-traversed a subtree that `Expression.exists` already visits, so `Alias` is now simply
    listed among the allowed node types. This removes unnecessary overhead without changing
    which arguments are accepted.
    
    ### Why are the changes needed?
    To ease code maintenance between single-pass and fixed-point analyzer 
implementations.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Existing tests (refactor).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #52744 from mihailoale-db/refactorparamcheckargs.
    
    Authored-by: mihailoale-db <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../ParameterizedQueryArgumentsValidator.scala     | 59 ++++++++++++++++++++++
 .../spark/sql/catalyst/analysis/parameters.scala   | 21 ++------
 .../apache/spark/sql/classic/SparkSession.scala    | 32 ++++--------
 3 files changed, 73 insertions(+), 39 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ParameterizedQueryArgumentsValidator.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ParameterizedQueryArgumentsValidator.scala
new file mode 100644
index 000000000000..69c68e5f2a5d
--- /dev/null
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ParameterizedQueryArgumentsValidator.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.expressions.{
+  Alias,
+  CreateArray,
+  CreateMap,
+  CreateNamedStruct,
+  Expression,
+  Literal,
+  MapFromArrays,
+  MapFromEntries,
+  VariableReference
+}
+
+/**
+ * Object used to validate arguments of [[ParameterizedQuery]] nodes.
+ */
+object ParameterizedQueryArgumentsValidator {
+
+  /**
+   * Validates the provided (name, expression) arguments. Throws an `INVALID_SQL_ARG` analysis
+   * error for the first argument (in iteration order) that contains a disallowed expression.
+   */
+  def apply(arguments: Iterable[(String, Expression)]): Unit = {
+    arguments.find { case (_, expr) => isNotAllowed(expr) }.foreach { case (name, expr) =>
+      expr.failAnalysis(
+        errorClass = "INVALID_SQL_ARG",
+        messageParameters = Map("name" -> name))
+    }
+  }
+
+  /**
+   * Returns `true` if any node of the expression tree is not an allowed argument expression,
+   * `false` otherwise. `Alias` needs no recursion here: `exists` already visits its child.
+   */
+  private def isNotAllowed(expression: Expression): Boolean = expression.exists {
+    case _: Literal | _: CreateArray | _: CreateNamedStruct | _: CreateMap | _: MapFromArrays |
+        _: MapFromEntries | _: VariableReference | _: Alias =>
+      false
+    case _ => true
+  }
+}
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala
index b3200be95f69..bf9acb775ce1 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.analysis
 
 import org.apache.spark.SparkException
-import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, 
CreateMap, CreateNamedStruct, Expression, LeafExpression, Literal, 
MapFromArrays, MapFromEntries, SubqueryExpression, Unevaluable, 
VariableReference}
+import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, 
SubqueryExpression, Unevaluable}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, 
SupervisingCommand}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMAND, PARAMETER, 
PARAMETERIZED_QUERY, TreePattern, UNRESOLVED_WITH}
@@ -173,19 +173,6 @@ object MoveParameterizedQueriesDown extends 
Rule[LogicalPlan] {
  * from the user-specified arguments.
  */
 object BindParameters extends Rule[LogicalPlan] with QueryErrorsBase {
-  private def checkArgs(args: Iterable[(String, Expression)]): Unit = {
-    def isNotAllowed(expr: Expression): Boolean = expr.exists {
-      case _: Literal | _: CreateArray | _: CreateNamedStruct |
-        _: CreateMap | _: MapFromArrays |  _: MapFromEntries | _: 
VariableReference => false
-      case a: Alias => isNotAllowed(a.child)
-      case _ => true
-    }
-    args.find(arg => isNotAllowed(arg._2)).foreach { case (name, expr) =>
-      expr.failAnalysis(
-        errorClass = "INVALID_SQL_ARG",
-        messageParameters = Map("name" -> name))
-    }
-  }
 
   private def bind(p0: LogicalPlan)(f: PartialFunction[Expression, 
Expression]): LogicalPlan = {
     var stop = false
@@ -210,7 +197,7 @@ object BindParameters extends Rule[LogicalPlan] with 
QueryErrorsBase {
             s"must be equal to the number of argument values 
${argValues.length}.")
         }
         val args = argNames.zip(argValues).toMap
-        checkArgs(args)
+        ParameterizedQueryArgumentsValidator(args)
         bind(child) { case NamedParameter(name) if args.contains(name) => 
args(name) }
 
       case PosParameterizedQuery(child, args)
@@ -218,7 +205,7 @@ object BindParameters extends Rule[LogicalPlan] with 
QueryErrorsBase {
           args.forall(_.resolved) =>
 
         val indexedArgs = args.zipWithIndex
-        checkArgs(indexedArgs.map(arg => (s"_${arg._2}", arg._1)))
+        ParameterizedQueryArgumentsValidator(indexedArgs.map(arg => 
(s"_${arg._2}", arg._1)))
 
         val positions = scala.collection.mutable.Set.empty[Int]
         bind(child) { case p @ PosParameter(pos) => positions.add(pos); p }
@@ -238,7 +225,7 @@ object BindParameters extends Rule[LogicalPlan] with 
QueryErrorsBase {
           val finalName = if (name.isEmpty) s"_$index" else name
           finalName -> arg
         }
-        checkArgs(allArgs)
+        ParameterizedQueryArgumentsValidator(allArgs)
 
         // Collect parameter types used in the query to enforce invariants
         var hasNamedParam = false
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala
index c6af76564b76..9166bc946b54 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala
@@ -40,9 +40,9 @@ import org.apache.spark.sql
 import org.apache.spark.sql.{AnalysisException, Artifact, 
DataSourceRegistration, Encoder, Encoders, ExperimentalMethods, Row, 
SparkSessionBuilder, SparkSessionCompanion, SparkSessionExtensions, 
SparkSessionExtensionsProvider, UDTFRegistration}
 import org.apache.spark.sql.artifact.ArtifactManager
 import org.apache.spark.sql.catalyst._
-import org.apache.spark.sql.catalyst.analysis.{GeneralParameterizedQuery, 
NameParameterizedQuery, PosParameterizedQuery, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.analysis.{GeneralParameterizedQuery, 
NameParameterizedQuery, ParameterizedQueryArgumentsValidator, 
PosParameterizedQuery, UnresolvedRelation}
 import org.apache.spark.sql.catalyst.encoders._
-import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, 
CreateArray, CreateMap, CreateNamedStruct, Expression, Literal, MapFromArrays, 
MapFromEntries, VariableReference}
+import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, 
Expression, Literal}
 import org.apache.spark.sql.catalyst.parser.{HybridParameterContext, 
NamedParameterContext, ParserInterface, PositionalParameterContext}
 import org.apache.spark.sql.catalyst.plans.logical.{CompoundBody, 
LocalRelation, OneRowRelation, Project, Range}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
@@ -453,27 +453,15 @@ class SparkSession private(
     // Analyze the plan to resolve expressions
     val analyzed = sessionState.analyzer.execute(fakePlan)
 
-    // Validate: the expression tree must only contain allowed expression 
types.
-    // This mirrors the validation in BindParameters.checkArgs.
-    // We check this BEFORE optimization to catch unsupported functions like 
str_to_map.
-    def isNotAllowed(expr: Expression): Boolean = expr.exists {
-      case _: Literal | _: CreateArray | _: CreateNamedStruct |
-        _: CreateMap | _: MapFromArrays | _: MapFromEntries | _: 
VariableReference => false
-      case a: Alias => isNotAllowed(a.child)
-      case _ => true
-    }
-
-    analyzed.asInstanceOf[Project].projectList.foreach { alias =>
-      val optimizedExpr = alias.asInstanceOf[Alias].child
-      if (isNotAllowed(optimizedExpr)) {
-        // Both modern and legacy modes use INVALID_SQL_ARG for sql() API 
argument validation.
-        // UNSUPPORTED_EXPR_FOR_PARAMETER is reserved for EXECUTE IMMEDIATE.
-        throw new AnalysisException(
-          errorClass = "INVALID_SQL_ARG",
-          messageParameters = Map("name" -> alias.name),
-          origin = optimizedExpr.origin)
-      }
+    val expressionsToValidate = analyzed.asInstanceOf[Project].projectList.map 
{
+      case alias: Alias =>
+        (alias.name, alias.child)
+      case other =>
+        throw SparkException.internalError(
+          s"Expected an Alias, but got ${other.getClass.getSimpleName}"
+        )
     }
+    ParameterizedQueryArgumentsValidator(expressionsToValidate)
 
     // Optimize to constant-fold expressions. After optimization, all allowed 
expressions
     // should be folded to Literals.


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to