This is an automated email from the ASF dual-hosted git repository.
sunchao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 1f9eaf8f27a7 [SPARK-57175][SQL] Extend nested column pruning to exists
and forall over arrays of structs
1f9eaf8f27a7 is described below
commit 1f9eaf8f27a781a8641a90888c1e6423909064cb
Author: Chao Sun <[email protected]>
AuthorDate: Tue Jun 2 12:18:27 2026 -0700
[SPARK-57175][SQL] Extend nested column pruning to exists and forall over
arrays of structs
### Why are the changes needed?
[SPARK-57175](https://issues.apache.org/jira/browse/SPARK-57175) follows
[SPARK-57022](https://issues.apache.org/jira/browse/SPARK-57022), which added
nested column pruning for `transform` over `array<struct>` inputs. The same
optimization does not currently apply to the `exists` and `forall` higher-order
array functions.
For example:
```sql
SELECT exists(rule_results, rule -> rule.rule_version > 10)
FROM events
```
If the elements of `rule_results` contain additional fields, Spark currently keeps
the full element struct in the scan schema even though the predicate only reads
`rule_version`. For wide element structs, this makes Parquet and ORC scans read
columns that the query never uses.
### What changes were proposed in this PR?
- Share the nested-field collection path introduced for `ArrayTransform`
with `ArrayExists` and `ArrayForAll`.
- Rewrite the bound lambda variable type and `GetStructField` ordinals
against the projected element schema after pruning.
- Keep the conservative fallback (no pruning of the element struct) when a lambda
consumes the whole element rather than individual fields.
- Add Catalyst and datasource tests covering schema discovery, ordinal
rewrites, predicate-path schema merging, and whole-element fallback.
`ArrayFilter` and `ArraySort` remain out of scope because they return the
original input elements: pruning their input would change the element schema
seen downstream, so they require a different downstream-schema design.
### Does this PR introduce _any_ user-facing change?
Yes. Eligible queries using `exists` or `forall` over arrays of structs can
read a narrower input schema. Query results and SQL APIs are unchanged.
### How was this patch tested?
- `JAVA_HOME=/opt/homebrew/opt/openjdk17/libexec/openjdk.jdk/Contents/Home
PATH=/opt/homebrew/opt/openjdk17/bin:$PATH build/sbt "catalyst/testOnly
org.apache.spark.sql.catalyst.expressions.SchemaPruningSuite" "sql/testOnly
org.apache.spark.sql.execution.datasources.parquet.ParquetV1SchemaPruningSuite
org.apache.spark.sql.execution.datasources.parquet.ParquetV2SchemaPruningSuite
org.apache.spark.sql.execution.datasources.orc.OrcV1SchemaPruningSuite
org.apache.spark.sql.execution.dataso [...]
- `git diff --check`
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Codex (GPT-5)
Closes #56226 from sunchao/dev/chao/codex/spark-array-predicate-pruning.
Authored-by: Chao Sun <[email protected]>
Signed-off-by: Chao Sun <[email protected]>
---
.../expressions/ProjectionOverSchema.scala | 63 ++++++++++++++--------
.../sql/catalyst/expressions/SchemaPruning.scala | 37 ++++++++-----
.../catalyst/expressions/SchemaPruningSuite.scala | 39 ++++++++++++++
.../execution/datasources/SchemaPruningSuite.scala | 48 +++++++++++++++++
4 files changed, 151 insertions(+), 36 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
index 362643016d83..27e014ecef62 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
@@ -69,29 +69,16 @@ case class ProjectionOverSchema(schema: StructType, output:
AttributeSet) {
case GetMapValue(child, key) =>
getProjection(child).map { projection => GetMapValue(projection, key) }
case transform @ ArrayTransform(argument, lambda: LambdaFunction) =>
- getProjection(argument).map {
- case projection @ ArrayTypeProjection(projectedElementSchema) =>
- lambda.arguments.headOption match {
- case Some(elementVar: NamedLambdaVariable) =>
- // Pruning fields changes the physical ordinal layout of the
element struct.
- // For example, pruning struct<a, b, c> to struct<a, c> moves
c from ordinal 2
- // to ordinal 1, so rewrite both the variable type and its
field accesses.
- val projectedElementVar = elementVar.copy(dataType =
projectedElementSchema)
- val lambdaProjection =
- ProjectionOverLambdaVariable(elementVar, projectedElementVar)
- val projectedBody = lambda.function.transformDown {
- case lambdaProjection(expr) => expr
- }
- transform.copy(
- argument = projection,
- function = lambda.copy(
- function = projectedBody,
- arguments = projectedElementVar +: lambda.arguments.tail))
- case _ =>
- transform.copy(argument = projection)
- }
- case projection =>
- transform.copy(argument = projection)
+ projectArrayHigherOrderFunction(argument, lambda) { (projection,
projectedLambda) =>
+ transform.copy(argument = projection, function = projectedLambda)
+ }
+ case exists @ ArrayExists(argument, lambda: LambdaFunction, _) =>
+ projectArrayHigherOrderFunction(argument, lambda) { (projection,
projectedLambda) =>
+ exists.copy(argument = projection, function = projectedLambda)
+ }
+ case forall @ ArrayForAll(argument, lambda: LambdaFunction) =>
+ projectArrayHigherOrderFunction(argument, lambda) { (projection,
projectedLambda) =>
+ forall.copy(argument = projection, function = projectedLambda)
}
case GetStructFieldObject(child, field: StructField) =>
getProjection(child).map(p => (p, p.dataType)).map {
@@ -108,6 +95,36 @@ case class ProjectionOverSchema(schema: StructType, output:
AttributeSet) {
None
}
+ /**
+ * Projects the array input of an array higher-order function (`ArrayTransform`,
+ * `ArrayExists`, `ArrayForAll`) and rebuilds the function through `rebuild`.
+ *
+ * When the projected argument is an array of structs and the lambda's first argument is a
+ * `NamedLambdaVariable`, that variable's type is replaced by the projected element schema and
+ * its field accesses in the lambda body are rewritten to the new ordinals. Any remaining
+ * lambda arguments are kept as-is. In every other case the lambda is passed to `rebuild`
+ * unchanged.
+ *
+ * @param argument the array argument of the higher-order function
+ * @param lambda the lambda of the higher-order function
+ * @param rebuild builds the new expression from the projected argument and (possibly
+ * rewritten) lambda
+ * @return `None` when `getProjection(argument)` yields no projection, otherwise the rebuilt
+ * expression
+ */
+ private def projectArrayHigherOrderFunction(
+ argument: Expression,
+ lambda: LambdaFunction)(
+ rebuild: (Expression, LambdaFunction) => Expression): Option[Expression]
= {
+ getProjection(argument).map {
+ case projection @ ArrayTypeProjection(projectedElementSchema) =>
+ lambda.arguments.headOption match {
+ case Some(elementVar: NamedLambdaVariable) =>
+ // Pruning fields changes the physical ordinal layout of the element struct.
+ // For example, pruning struct<a, b, c> to struct<a, c> moves c from ordinal 2
+ // to ordinal 1, so rewrite both the variable type and its field accesses.
+ val projectedElementVar = elementVar.copy(dataType =
projectedElementSchema)
+ val lambdaProjection =
+ ProjectionOverLambdaVariable(elementVar, projectedElementVar)
+ val projectedBody = lambda.function.transformDown {
+ case lambdaProjection(expr) => expr
+ }
+ rebuild(
+ projection,
+ lambda.copy(
+ function = projectedBody,
+ arguments = projectedElementVar +: lambda.arguments.tail))
+ // No named element variable to retype: keep the lambda unchanged.
+ case _ =>
+ rebuild(projection, lambda)
+ }
+ // The projected argument is not an array of structs: keep the lambda unchanged.
+ case projection =>
+ rebuild(projection, lambda)
+ }
+ }
+
private object ArrayTypeProjection {
def unapply(expr: Expression): Option[StructType] = expr.dataType match {
case ArrayType(projectedElementSchema: StructType, _) =>
Some(projectedElementSchema)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
index 2f99dd54f77a..e8aa722bbe23 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
@@ -141,18 +141,11 @@ object SchemaPruning extends SQLConfHelper {
private[catalyst] def getRootFields(expr: Expression): Seq[RootField] = {
expr match {
case ArrayTransform(argument, lambda: LambdaFunction) =>
- // Field accesses through the lambda variable are not directly rooted
at the input
- // attribute. Convert them into a projected type for the transform
argument so that
- // physical nested column pruning can see them.
- val nestedRootFields = lambda.arguments.headOption.collect {
- case elementVar: NamedLambdaVariable =>
- getArrayTransformRootField(argument, lambda.function, elementVar)
- }.flatten.toSeq.map(field => RootField(field, derivedFromAtt = false))
- if (nestedRootFields.nonEmpty) {
- nestedRootFields ++ getRootFields(lambda.function)
- } else {
- expr.children.flatMap(getRootFields)
- }
+ getArrayHigherOrderFunctionRootFields(expr, argument, lambda)
+ case ArrayExists(argument, lambda: LambdaFunction, _) =>
+ getArrayHigherOrderFunctionRootFields(expr, argument, lambda)
+ case ArrayForAll(argument, lambda: LambdaFunction) =>
+ getArrayHigherOrderFunctionRootFields(expr, argument, lambda)
case att: Attribute =>
RootField(StructField(att.name, att.dataType, att.nullable,
att.metadata),
derivedFromAtt = true) :: Nil
@@ -175,7 +168,25 @@ object SchemaPruning extends SQLConfHelper {
}
}
- private def getArrayTransformRootField(
+ /**
+ * Collects root fields for an array higher-order function `expr` (`ArrayTransform`,
+ * `ArrayExists`, `ArrayForAll`) whose array input is `argument` and whose lambda is `lambda`.
+ *
+ * If the lambda's first argument is a `NamedLambdaVariable` and
+ * `getArrayHigherOrderFunctionRootField` derives a field from it, that field is returned
+ * (with `derivedFromAtt = false`) together with the root fields of the lambda body.
+ * Otherwise this falls back to the root fields of all of `expr`'s children. That fallback is
+ * the conservative path, taken for example when the lambda consumes the whole element.
+ */
+ private def getArrayHigherOrderFunctionRootFields(
+ expr: Expression,
+ argument: Expression,
+ lambda: LambdaFunction): Seq[RootField] = {
+ // Field accesses through the lambda variable are not directly rooted at the input
+ // attribute. Convert them into a projected type for the array argument so that
+ // physical nested column pruning can see them.
+ val nestedRootFields = lambda.arguments.headOption.collect {
+ case elementVar: NamedLambdaVariable =>
+ getArrayHigherOrderFunctionRootField(argument, lambda.function,
elementVar)
+ }.flatten.toSeq.map(field => RootField(field, derivedFromAtt = false))
+ if (nestedRootFields.nonEmpty) {
+ // The lambda body may also reference columns other than the element; keep those too.
+ nestedRootFields ++ getRootFields(lambda.function)
+ } else {
+ expr.children.flatMap(getRootFields)
+ }
+ }
+
+ private def getArrayHigherOrderFunctionRootField(
argument: Expression,
function: Expression,
elementVar: NamedLambdaVariable): Option[StructField] = {
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
index 9426ef91349e..af64da7e3820 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
@@ -186,4 +186,43 @@ class SchemaPruningSuite extends SparkFunSuite with
SQLHelper {
StructField("event", eventType, nullable = true),
derivedFromAtt = false)))
}
+
+ // A predicate lambda that reads only field `c` of each element should prune the element
+ // struct of `event.rules` to struct<c:int>, for both exists and forall.
+ test("collect nested fields used by ArrayExists and ArrayForAll lambdas") {
+ val elementType = StructType.fromDDL("a int, b int, c int")
+ val eventType = StructType(Seq(
+ StructField("rules", ArrayType(elementType, containsNull = true))))
+ val event = AttributeReference("event", eventType)()
+ val argument = GetStructField(event, 0, Some("rules"))
+ val element = NamedLambdaVariable("x", elementType, nullable = true)
+ val predicate = LambdaFunction(
+ GreaterThan(GetStructField(element, 2, Some("c")), Literal(0)),
+ Seq(element))
+
+ // The same lambda is shared by both functions; each must yield the pruned schema.
+ Seq(ArrayExists(argument, predicate), ArrayForAll(argument,
predicate)).foreach { function =>
+ val rootFields = SchemaPruning.getRootFields(function)
+ val prunedSchema = SchemaPruning.pruneSchema(
+ StructType(Seq(StructField("event", eventType))),
+ rootFields)
+
+ assert(prunedSchema === StructType.fromDDL(
+ "event struct<rules:array<struct<c:int>>>"))
+ }
+ }
+
+ // IsNotNull(x) consumes the whole element, so no nested field is collected and the
+ // expected root field is the unpruned `event` struct.
+ test("do not collect ArrayExists and ArrayForAll lambda fields when the
whole element is used") {
+ val elementType = StructType.fromDDL("a int, b int")
+ val eventType = StructType(Seq(
+ StructField("rules", ArrayType(elementType, containsNull = true))))
+ val event = AttributeReference("event", eventType)()
+ val argument = GetStructField(event, 0, Some("rules"))
+ val element = NamedLambdaVariable("x", elementType, nullable = true)
+ val predicate = LambdaFunction(IsNotNull(element), Seq(element))
+
+ Seq(ArrayExists(argument, predicate), ArrayForAll(argument,
predicate)).foreach { function =>
+ assert(SchemaPruning.getRootFields(function) === Seq(
+ SchemaPruning.RootField(
+ StructField("event", eventType, nullable = true),
+ derivedFromAtt = false)))
+ }
+ }
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
index 7f4b83ee342e..2aebf08286e1 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
@@ -470,6 +470,54 @@ abstract class SchemaPruningSuite
Nil)
}
+ // exists whose predicate reads only `last` should scan only friends.last.
+ testSchemaPruning("select ArrayExists over nested fields of array of
struct") {
+ val query = spark.table("contacts")
+ .where("p = 1")
+ .select(org.apache.spark.sql.functions.exists(
+ col("friends"), friend => friend.getField("last") === "Smith"))
+
+ checkScan(query, "struct<friends:array<struct<last:string>>>")
+ checkAnswer(query, Row(true) :: Row(false) :: Nil)
+ }
+
+ // forall whose predicate reads only `last` should scan only friends.last.
+ testSchemaPruning("select ArrayForAll over nested fields of array of
struct") {
+ val query = spark.table("contacts")
+ .where("p = 1")
+ .select(forall(col("friends"), friend => friend.getField("last") ===
"Smith"))
+
+ checkScan(query, "struct<friends:array<struct<last:string>>>")
+ checkAnswer(query, Row(true) :: Row(true) :: Nil)
+ }
+
+ // The exists filter reads `last` and the projection reads `first`; the scan schema must
+ // be the merge of both nested paths.
+ testSchemaPruning("select nested field with ArrayExists predicate") {
+ val query = spark.table("contacts")
+ .where("p = 1")
+ .where(org.apache.spark.sql.functions.exists(
+ col("friends"), friend => friend.getField("last") === "Smith"))
+ .select(col("friends").getField("first"))
+
+ checkScan(query, "struct<friends:array<struct<first:string,last:string>>>")
+ checkAnswer(query, Row(Array("Susan")) :: Nil)
+ }
+
+ // isNotNull on the element consumes the whole struct, so no element field is pruned.
+ testSchemaPruning("do not prune ArrayExists when the whole element is used")
{
+ val query = spark.table("contacts")
+ .where("p = 1")
+ .select(org.apache.spark.sql.functions.exists(col("friends"), friend =>
friend.isNotNull))
+
+ checkScan(query,
"struct<friends:array<struct<first:string,middle:string,last:string>>>")
+ checkAnswer(query, Row(true) :: Row(false) :: Nil)
+ }
+
+ // Same whole-element fallback as above, exercised through forall.
+ testSchemaPruning("do not prune ArrayForAll when the whole element is used")
{
+ val query = spark.table("contacts")
+ .where("p = 1")
+ .select(forall(col("friends"), friend => friend.isNotNull))
+
+ checkScan(query,
"struct<friends:array<struct<first:string,middle:string,last:string>>>")
+ checkAnswer(query, Row(true) :: Row(true) :: Nil)
+ }
+
testSchemaPruning("SPARK-34638: nested column prune on generator output") {
val query1 = spark.table("contacts")
.select(explode(col("friends")).as("friend"))
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org