This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 803ea9590145 [SPARK-54041][SQL] Refactor ParameterizedQuery arguments validation
803ea9590145 is described below
commit 803ea95901455d498b460ad7e52ed9ac94e118d9
Author: mihailoale-db <[email protected]>
AuthorDate: Wed Oct 29 20:27:41 2025 +0800
[SPARK-54041][SQL] Refactor ParameterizedQuery arguments validation
### What changes were proposed in this pull request?
In this PR I propose to extract the `ParameterizedQuery` arguments validation into
`ParameterizedQueryArgumentsValidator` so that it can be reused between the single-pass
and fixed-point analyzer implementations. We also remove one redundant case from
`ParameterizedQueryArgumentsValidator.isNotAllowed` to improve performance: the `Alias`
case shouldn't recursively call `isNotAllowed` on its child, because `exists` already
traverses the children and the extra call only adds overhead (see the sketch below).
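For illustration only, here is a minimal toy sketch of why the recursive `Alias` case was redundant. `Node` and `exists` below are simplified stand-ins for Catalyst's `Expression` and `TreeNode.exists`; this is not Spark code:
```scala
// Toy tree with the same traversal contract as TreeNode.exists:
// check the node itself first, then recurse into its children.
sealed trait Node { def children: Seq[Node] }
case class Lit() extends Node { def children = Nil }
case class Alias(child: Node) extends Node { def children = Seq(child) }
case class Other() extends Node { def children = Nil }

def exists(n: Node)(f: Node => Boolean): Boolean =
  f(n) || n.children.exists(exists(_)(f))

// Old predicate: restarts a full traversal at the Alias child, which the
// outer exists() is about to visit anyway, so the child is walked twice.
def isNotAllowedOld(n: Node): Boolean = exists(n) {
  case _: Lit   => false
  case a: Alias => isNotAllowedOld(a.child)
  case _        => true
}

// New predicate: Alias is simply allowed; the single outer traversal
// still reaches and checks its child exactly once.
def isNotAllowedNew(n: Node): Boolean = exists(n) {
  case _: Lit | _: Alias => false
  case _                 => true
}

// Both versions agree on the results:
assert(isNotAllowedOld(Alias(Other())) && isNotAllowedNew(Alias(Other())))
assert(!isNotAllowedOld(Alias(Lit())) && !isNotAllowedNew(Alias(Lit())))
```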
### Why are the changes needed?
To ease code maintenance between single-pass and fixed-point analyzer
implementations.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing tests (refactor).
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #52744 from mihailoale-db/refactorparamcheckargs.
Authored-by: mihailoale-db <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../ParameterizedQueryArgumentsValidator.scala | 59 ++++++++++++++++++++++
.../spark/sql/catalyst/analysis/parameters.scala | 21 ++------
.../apache/spark/sql/classic/SparkSession.scala | 32 ++++--------
3 files changed, 73 insertions(+), 39 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ParameterizedQueryArgumentsValidator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ParameterizedQueryArgumentsValidator.scala
new file mode 100644
index 000000000000..69c68e5f2a5d
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ParameterizedQueryArgumentsValidator.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.expressions.{
+ Alias,
+ CreateArray,
+ CreateMap,
+ CreateNamedStruct,
+ Expression,
+ Literal,
+ MapFromArrays,
+ MapFromEntries,
+ VariableReference
+}
+
+/**
+ * Object used to validate arguments of [[ParameterizedQuery]] nodes.
+ */
+object ParameterizedQueryArgumentsValidator {
+
+  /**
+   * Validates the list of provided arguments. If there is an invalid argument, throws an
+   * `INVALID_SQL_ARG` exception.
+   */
+  def apply(arguments: Iterable[(String, Expression)]): Unit = {
+    arguments.find(arg => isNotAllowed(arg._2)).foreach { case (name, expr) =>
+      expr.failAnalysis(
+        errorClass = "INVALID_SQL_ARG",
+        messageParameters = Map("name" -> name))
+    }
+  }
+
+  /**
+   * Recursively checks the provided expression tree. If there is an invalid expression type,
+   * returns `true`. Otherwise, returns `false`.
+   */
+  private def isNotAllowed(expression: Expression): Boolean = expression.exists {
+    case _: Literal | _: CreateArray | _: CreateNamedStruct | _: CreateMap | _: MapFromArrays |
+        _: MapFromEntries | _: VariableReference | _: Alias =>
+      false
+    case _ => true
+  }
+}
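As a rough usage sketch of the new object (the argument names and values below are illustrative, not from the patch):
```scala
import org.apache.spark.sql.catalyst.analysis.ParameterizedQueryArgumentsValidator
import org.apache.spark.sql.catalyst.expressions.{CreateArray, Literal}

// Literals, and array/map/struct constructors over literals, pass validation.
ParameterizedQueryArgumentsValidator(Seq(
  "limit" -> Literal(10),
  "ids"   -> CreateArray(Seq(Literal(1), Literal(2)))))

// Anything outside the allowed list (e.g. an unresolved function call or
// attribute) makes isNotAllowed return true, and apply() throws an
// INVALID_SQL_ARG analysis error naming the offending argument.
```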
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala
index b3200be95f69..bf9acb775ce1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.SparkException
-import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, CreateMap, CreateNamedStruct, Expression, LeafExpression, Literal, MapFromArrays, MapFromEntries, SubqueryExpression, Unevaluable, VariableReference}
+import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, SubqueryExpression, Unevaluable}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SupervisingCommand}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMAND, PARAMETER, PARAMETERIZED_QUERY, TreePattern, UNRESOLVED_WITH}
@@ -173,19 +173,6 @@ object MoveParameterizedQueriesDown extends Rule[LogicalPlan] {
* from the user-specified arguments.
*/
object BindParameters extends Rule[LogicalPlan] with QueryErrorsBase {
- private def checkArgs(args: Iterable[(String, Expression)]): Unit = {
- def isNotAllowed(expr: Expression): Boolean = expr.exists {
- case _: Literal | _: CreateArray | _: CreateNamedStruct |
- _: CreateMap | _: MapFromArrays | _: MapFromEntries | _: VariableReference => false
- case a: Alias => isNotAllowed(a.child)
- case _ => true
- }
- args.find(arg => isNotAllowed(arg._2)).foreach { case (name, expr) =>
- expr.failAnalysis(
- errorClass = "INVALID_SQL_ARG",
- messageParameters = Map("name" -> name))
- }
- }
private def bind(p0: LogicalPlan)(f: PartialFunction[Expression, Expression]): LogicalPlan = {
var stop = false
@@ -210,7 +197,7 @@ object BindParameters extends Rule[LogicalPlan] with QueryErrorsBase {
s"must be equal to the number of argument values
${argValues.length}.")
}
val args = argNames.zip(argValues).toMap
- checkArgs(args)
+ ParameterizedQueryArgumentsValidator(args)
bind(child) { case NamedParameter(name) if args.contains(name) => args(name) }
case PosParameterizedQuery(child, args)
@@ -218,7 +205,7 @@ object BindParameters extends Rule[LogicalPlan] with QueryErrorsBase {
args.forall(_.resolved) =>
val indexedArgs = args.zipWithIndex
- checkArgs(indexedArgs.map(arg => (s"_${arg._2}", arg._1)))
+ ParameterizedQueryArgumentsValidator(indexedArgs.map(arg => (s"_${arg._2}", arg._1)))
val positions = scala.collection.mutable.Set.empty[Int]
bind(child) { case p @ PosParameter(pos) => positions.add(pos); p }
@@ -238,7 +225,7 @@ object BindParameters extends Rule[LogicalPlan] with QueryErrorsBase {
val finalName = if (name.isEmpty) s"_$index" else name
finalName -> arg
}
- checkArgs(allArgs)
+ ParameterizedQueryArgumentsValidator(allArgs)
// Collect parameter types used in the query to enforce invariants
var hasNamedParam = false
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala
index c6af76564b76..9166bc946b54 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala
@@ -40,9 +40,9 @@ import org.apache.spark.sql
import org.apache.spark.sql.{AnalysisException, Artifact, DataSourceRegistration, Encoder, Encoders, ExperimentalMethods, Row, SparkSessionBuilder, SparkSessionCompanion, SparkSessionExtensions, SparkSessionExtensionsProvider, UDTFRegistration}
import org.apache.spark.sql.artifact.ArtifactManager
import org.apache.spark.sql.catalyst._
-import org.apache.spark.sql.catalyst.analysis.{GeneralParameterizedQuery, NameParameterizedQuery, PosParameterizedQuery, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.analysis.{GeneralParameterizedQuery, NameParameterizedQuery, ParameterizedQueryArgumentsValidator, PosParameterizedQuery, UnresolvedRelation}
import org.apache.spark.sql.catalyst.encoders._
-import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, CreateArray, CreateMap, CreateNamedStruct, Expression, Literal, MapFromArrays, MapFromEntries, VariableReference}
+import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression, Literal}
import org.apache.spark.sql.catalyst.parser.{HybridParameterContext, NamedParameterContext, ParserInterface, PositionalParameterContext}
import org.apache.spark.sql.catalyst.plans.logical.{CompoundBody, LocalRelation, OneRowRelation, Project, Range}
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
@@ -453,27 +453,15 @@ class SparkSession private(
// Analyze the plan to resolve expressions
val analyzed = sessionState.analyzer.execute(fakePlan)
- // Validate: the expression tree must only contain allowed expression types.
- // This mirrors the validation in BindParameters.checkArgs.
- // We check this BEFORE optimization to catch unsupported functions like str_to_map.
- def isNotAllowed(expr: Expression): Boolean = expr.exists {
- case _: Literal | _: CreateArray | _: CreateNamedStruct |
- _: CreateMap | _: MapFromArrays | _: MapFromEntries | _: VariableReference => false
- case a: Alias => isNotAllowed(a.child)
- case _ => true
- }
-
- analyzed.asInstanceOf[Project].projectList.foreach { alias =>
- val optimizedExpr = alias.asInstanceOf[Alias].child
- if (isNotAllowed(optimizedExpr)) {
- // Both modern and legacy modes use INVALID_SQL_ARG for sql() API argument validation.
- // UNSUPPORTED_EXPR_FOR_PARAMETER is reserved for EXECUTE IMMEDIATE.
- throw new AnalysisException(
- errorClass = "INVALID_SQL_ARG",
- messageParameters = Map("name" -> alias.name),
- origin = optimizedExpr.origin)
- }
+ val expressionsToValidate = analyzed.asInstanceOf[Project].projectList.map {
+ case alias: Alias =>
+ (alias.name, alias.child)
+ case other =>
+ throw SparkException.internalError(
+ s"Expected an Alias, but got ${other.getClass.getSimpleName}"
+ )
}
+ ParameterizedQueryArgumentsValidator(expressionsToValidate)
// Optimize to constant-fold expressions. After optimization, all allowed expressions
// should be folded to Literals.
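For context, a condensed, hypothetical sketch of the flow this hunk implements; `fakePlan` below is a stand-in built directly from a literal, since the real plan construction is outside this hunk:
```scala
import org.apache.spark.sql.catalyst.analysis.ParameterizedQueryArgumentsValidator
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}

// Stand-in for the analyzed fake plan: each argument is an Alias in a
// single Project over a one-row relation, as the asInstanceOf[Project]
// cast in the patch implies.
val fakePlan = Project(Seq(Alias(Literal(42), "limit")()), OneRowRelation())

// Recover (name, expression) pairs and hand them to the shared validator,
// which replaces the previously inlined isNotAllowed check.
val pairs = fakePlan.projectList.map { case a: Alias => (a.name, a.child) }
ParameterizedQueryArgumentsValidator(pairs) // passes: Literal is allowed
```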
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]