This is an automated email from the ASF dual-hosted git repository.
sunchao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 042ad7d0c4ac [SPARK-57176][SQL] Extend nested column pruning through
array-returning functions
042ad7d0c4ac is described below
commit 042ad7d0c4ac1c4d3e9fdeb48e2695fdeb861135
Author: Chao Sun <[email protected]>
AuthorDate: Fri Jun 5 09:57:03 2026 -0700
[SPARK-57176][SQL] Extend nested column pruning through array-returning
functions
### Why are the changes needed?
[SPARK-57176](https://issues.apache.org/jira/browse/SPARK-57176) follows
[SPARK-57022](https://issues.apache.org/jira/browse/SPARK-57022), which added
nested column pruning for `transform` over `array<struct>` inputs.
Array-returning functions still retain the complete input element struct
even when downstream expressions and lambdas only require a subset of nested
fields. For example:
```sql
SELECT filter(friends, friend -> friend.last = 'Smith').first
FROM contacts
```
If `friends` contains `first`, `middle`, and `last`, Spark currently reads
all three fields even though the query only requires `first` and `last`.
### What changes were proposed in this PR?
- Merge downstream result-field requirements with lambda requirements for
`filter` and comparator-based `array_sort`.
- Propagate projected element schemas through `reverse`, `shuffle`,
`slice` (only when its start and length are foldable), and `array_compact`.
- Rewrite bound lambda variable types and nested field ordinals after
pruning.
- Retain the complete element schema when the whole result is used, when a
lambda consumes the whole element, when default `array_sort` natural
ordering requires the full struct, or when a strict `array_sort` comparator
can return null (its error reports the compared values).
Functions that inspect full element equality or natural ordering remain out
of scope because dropping nested fields could change results.
### Does this PR introduce _any_ user-facing change?
Yes. Eligible queries using array-returning functions over arrays of
structs can read a narrower input schema. Query results and SQL APIs are
unchanged.
### How was this patch tested?
- `JAVA_HOME=/opt/homebrew/opt/openjdk17/libexec/openjdk.jdk/Contents/Home
PATH=/opt/homebrew/opt/openjdk17/bin:$PATH build/sbt "catalyst/testOnly
org.apache.spark.sql.catalyst.expressions.SchemaPruningSuite" "sql/testOnly
org.apache.spark.sql.execution.datasources.parquet.ParquetV1SchemaPruningSuite
org.apache.spark.sql.execution.datasources.parquet.ParquetV2SchemaPruningSuite
org.apache.spark.sql.execution.datasources.orc.OrcV1SchemaPruningSuite
org.apache.spark.sql.execution.dataso[...]` (the rest of this command was truncated in the notification email)
- `JAVA_HOME=/opt/homebrew/opt/openjdk17/libexec/openjdk.jdk/Contents/Home
PATH=/opt/homebrew/opt/openjdk17/bin:$PATH build/sbt catalyst/scalastyle
sql/scalastyle`
- `git diff --check`
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Codex (GPT-5)
Closes #56227 from
sunchao/dev/chao/codex/spark-array-returning-function-pruning.
Authored-by: Chao Sun <[email protected]>
Signed-off-by: Chao Sun <[email protected]>
---
.../expressions/ProjectionOverSchema.scala | 77 +++++---
.../sql/catalyst/expressions/SchemaPruning.scala | 75 ++++++-
.../sql/catalyst/expressions/SelectedField.scala | 20 ++
.../catalyst/expressions/SchemaPruningSuite.scala | 219 +++++++++++++++++++++
.../execution/datasources/SchemaPruningSuite.scala | 143 ++++++++++++++
5 files changed, 502 insertions(+), 32 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
index 27e014ecef62..503edd773933 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
@@ -69,17 +69,38 @@ case class ProjectionOverSchema(schema: StructType, output:
AttributeSet) {
case GetMapValue(child, key) =>
getProjection(child).map { projection => GetMapValue(projection, key) }
case transform @ ArrayTransform(argument, lambda: LambdaFunction) =>
- projectArrayHigherOrderFunction(argument, lambda) { (projection,
projectedLambda) =>
- transform.copy(argument = projection, function = projectedLambda)
+ projectArrayHigherOrderFunction(argument, lambda, numElementVariables
= 1) {
+ (projection, projectedLambda) =>
+ transform.copy(argument = projection, function = projectedLambda)
}
case exists @ ArrayExists(argument, lambda: LambdaFunction, _) =>
- projectArrayHigherOrderFunction(argument, lambda) { (projection,
projectedLambda) =>
- exists.copy(argument = projection, function = projectedLambda)
+ projectArrayHigherOrderFunction(argument, lambda, numElementVariables
= 1) {
+ (projection, projectedLambda) =>
+ exists.copy(argument = projection, function = projectedLambda)
}
case forall @ ArrayForAll(argument, lambda: LambdaFunction) =>
- projectArrayHigherOrderFunction(argument, lambda) { (projection,
projectedLambda) =>
- forall.copy(argument = projection, function = projectedLambda)
+ projectArrayHigherOrderFunction(argument, lambda, numElementVariables
= 1) {
+ (projection, projectedLambda) =>
+ forall.copy(argument = projection, function = projectedLambda)
}
+ case filter @ ArrayFilter(argument, lambda: LambdaFunction) =>
+ projectArrayHigherOrderFunction(argument, lambda, numElementVariables
= 1) {
+ (projection, projectedLambda) =>
+ filter.copy(argument = projection, function = projectedLambda)
+ }
+ case sort @ ArraySort(argument, lambda: LambdaFunction, _) =>
+ projectArrayHigherOrderFunction(argument, lambda, numElementVariables
= 2) {
+ (projection, projectedLambda) =>
+ sort.copy(argument = projection, function = projectedLambda)
+ }
+ case reverse @ Reverse(child) =>
+ getProjection(child).map(projection => reverse.copy(child =
projection))
+ case shuffle @ Shuffle(child, _) =>
+ getProjection(child).map(projection => shuffle.copy(child =
projection))
+ case slice @ Slice(x, start, length) if start.foldable &&
length.foldable =>
+ getProjection(x).map(projection => slice.copy(x = projection))
+ case knownNotContainsNull @ KnownNotContainsNull(child) =>
+ getProjection(child).map(projection => knownNotContainsNull.copy(child
= projection))
case GetStructFieldObject(child, field: StructField) =>
getProjection(child).map(p => (p, p.dataType)).map {
case (projection, projSchema: StructType) =>
@@ -97,29 +118,32 @@ case class ProjectionOverSchema(schema: StructType,
output: AttributeSet) {
private def projectArrayHigherOrderFunction(
argument: Expression,
- lambda: LambdaFunction)(
+ lambda: LambdaFunction,
+ numElementVariables: Int)(
rebuild: (Expression, LambdaFunction) => Expression): Option[Expression]
= {
getProjection(argument).map {
case projection @ ArrayTypeProjection(projectedElementSchema) =>
- lambda.arguments.headOption match {
- case Some(elementVar: NamedLambdaVariable) =>
- // Pruning fields changes the physical ordinal layout of the
element struct.
- // For example, pruning struct<a, b, c> to struct<a, c> moves c
from ordinal 2
- // to ordinal 1, so rewrite both the variable type and its field
accesses.
- val projectedElementVar = elementVar.copy(dataType =
projectedElementSchema)
- val lambdaProjection =
- ProjectionOverLambdaVariable(elementVar, projectedElementVar)
- val projectedBody = lambda.function.transformDown {
- case lambdaProjection(expr) => expr
- }
- rebuild(
- projection,
- lambda.copy(
- function = projectedBody,
- arguments = projectedElementVar +: lambda.arguments.tail))
- case _ =>
- rebuild(projection, lambda)
+ val projectedArguments = lambda.arguments.zipWithIndex.map {
+ case (elementVar: NamedLambdaVariable, index) if index <
numElementVariables =>
+ elementVar.copy(dataType = projectedElementSchema)
+ case (argument, _) =>
+ argument
}
+ val projectedBody =
+ lambda.arguments.zip(projectedArguments).foldLeft(lambda.function) {
+ case (body, (elementVar: NamedLambdaVariable, projectedElementVar:
+ NamedLambdaVariable)) if elementVar ne projectedElementVar =>
+ val lambdaProjection =
+ ProjectionOverLambdaVariable(elementVar, projectedElementVar)
+ body.transformDown {
+ case lambdaProjection(expr) => expr
+ }
+ case (body, _) =>
+ body
+ }
+ rebuild(
+ projection,
+ lambda.copy(function = projectedBody, arguments =
projectedArguments))
case projection =>
rebuild(projection, lambda)
}
@@ -135,6 +159,7 @@ case class ProjectionOverSchema(schema: StructType, output:
AttributeSet) {
/**
* Rewrites references rooted at one bound lambda element to use its
projected type and
* recomputes nested field ordinals against each projected struct in the
access path.
+ * Bound lambda references are matched by exprId because they may be
instantiated separately.
* This must support the same access paths collected by `SchemaPruning` for
lambda variables;
* currently both sides support only `GetStructField` chains.
*/
@@ -144,7 +169,7 @@ case class ProjectionOverSchema(schema: StructType, output:
AttributeSet) {
def unapply(expr: Expression): Option[Expression] = project(expr)
private def project(expr: Expression): Option[Expression] = expr match {
- case variable: NamedLambdaVariable if variable.semanticEquals(original)
=>
+ case variable: NamedLambdaVariable if variable.exprId == original.exprId
=>
Some(projected)
case GetStructFieldObject(child, field: StructField) =>
project(child).map { projection =>
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
index e8aa722bbe23..a0f1bec46410 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
@@ -149,14 +149,18 @@ object SchemaPruning extends SQLConfHelper {
case att: Attribute =>
RootField(StructField(att.name, att.dataType, att.nullable,
att.metadata),
derivedFromAtt = true) :: Nil
- case SelectedField(field) => RootField(field, derivedFromAtt = false) ::
Nil
+ case SelectedField(field) =>
+ RootField(field, derivedFromAtt = false) +:
+ getArrayReturningHigherOrderFunctionRootFields(expr)
// Root field accesses by `IsNotNull` and `IsNull` are special cases as
the expressions
// don't actually use any nested fields. These root field accesses might
be excluded later
// if there are any nested fields accesses in the query plan.
case IsNotNull(SelectedField(field)) =>
- RootField(field, derivedFromAtt = false, prunedIfAnyChildAccessed =
true) :: Nil
+ RootField(field, derivedFromAtt = false, prunedIfAnyChildAccessed =
true) +:
+ getArrayReturningHigherOrderFunctionRootFields(expr)
case IsNull(SelectedField(field)) =>
- RootField(field, derivedFromAtt = false, prunedIfAnyChildAccessed =
true) :: Nil
+ RootField(field, derivedFromAtt = false, prunedIfAnyChildAccessed =
true) +:
+ getArrayReturningHigherOrderFunctionRootFields(expr)
case IsNotNull(_: Attribute) | IsNull(_: Attribute) =>
expr.children.flatMap(getRootFields).map(_.copy(prunedIfAnyChildAccessed =
true))
case s: SubqueryExpression =>
@@ -180,7 +184,8 @@ object SchemaPruning extends SQLConfHelper {
getArrayHigherOrderFunctionRootField(argument, lambda.function,
elementVar)
}.flatten.toSeq.map(field => RootField(field, derivedFromAtt = false))
if (nestedRootFields.nonEmpty) {
- nestedRootFields ++ getRootFields(lambda.function)
+ nestedRootFields ++ getRootFields(lambda.function) ++
+ getArrayReturningHigherOrderFunctionRootFields(argument)
} else {
expr.children.flatMap(getRootFields)
}
@@ -208,6 +213,58 @@ object SchemaPruning extends SQLConfHelper {
}
}
+ private def getArrayReturningHigherOrderFunctionRootFields(expr:
Expression): Seq[RootField] = {
+ expr match {
+ case ArrayFilter(argument, lambda: LambdaFunction) =>
+ getArrayReturningHigherOrderFunctionRootFields(argument, lambda,
numElementVariables = 1)
+ case ArraySort(argument, lambda: LambdaFunction,
allowNullComparisonResult)
+ if !allowNullComparisonResult && lambda.function.nullable =>
+ // The strict null-comparator error includes the compared values, so
pruning their
+ // element fields would change the observable error parameters.
+ getRootFields(argument) ++ getRootFields(lambda.function)
+ case ArraySort(argument, lambda: LambdaFunction, _) =>
+ getArrayReturningHigherOrderFunctionRootFields(argument, lambda,
numElementVariables = 2)
+ case _ =>
+ expr.children.flatMap(getArrayReturningHigherOrderFunctionRootFields)
+ }
+ }
+
+ private def getArrayReturningHigherOrderFunctionRootFields(
+ argument: Expression,
+ lambda: LambdaFunction,
+ numElementVariables: Int): Seq[RootField] = {
+ val nestedRootFields = argument.dataType match {
+ case ArrayType(_: StructType, containsNull) =>
+ val elementVariables =
lambda.arguments.take(numElementVariables).collect {
+ case elementVar: NamedLambdaVariable => elementVar
+ }
+ val selectedFields =
elementVariables.map(collectLambdaVariableFields(lambda.function, _))
+ if (elementVariables.length == numElementVariables &&
selectedFields.forall(_.isDefined)) {
+ val fields = selectedFields.flatten.flatten
+ if (fields.nonEmpty) {
+ val mergedElementSchema = fields
+ .map(field => StructType(Array(field)))
+ .reduceLeft(_ merge _)
+ SelectedField.withDataType(
+ argument,
+ ArrayType(mergedElementSchema, containsNull)).toSeq match {
+ case Seq() => getRootFields(argument).map(_.field)
+ case fields => fields
+ }
+ } else {
+ Seq.empty
+ }
+ } else {
+ getRootFields(argument).map(_.field)
+ }
+ case _ =>
+ getRootFields(argument).map(_.field)
+ }
+ nestedRootFields.map(field => RootField(field, derivedFromAtt = false)) ++
+ getRootFields(lambda.function) ++
+ getArrayReturningHigherOrderFunctionRootFields(argument)
+ }
+
/**
* Collects statically identifiable nested fields read from `elementVar`.
*
@@ -216,6 +273,8 @@ object SchemaPruning extends SQLConfHelper {
* means the full element is required somewhere (for example, `x =>
struct(x.a, x)`), so it is
* not safe to prune the element struct.
*
+ * Bound lambda references are matched by exprId because they may be
instantiated separately.
+ *
* Currently only `GetStructField` chains rooted at `elementVar` are
collected; array or map
* traversal within the lambda conservatively requires the full element.
Keep this set of
* supported paths in sync with `ProjectionOverLambdaVariable` in
`ProjectionOverSchema`.
@@ -224,9 +283,13 @@ object SchemaPruning extends SQLConfHelper {
expr: Expression,
elementVar: NamedLambdaVariable): Option[Seq[StructField]] = {
expr match {
- case LambdaVariableField(field, variable) if
variable.semanticEquals(elementVar) =>
+ case LambdaVariableField(field, variable) if variable.exprId ==
elementVar.exprId =>
Some(field :: Nil)
- case variable: NamedLambdaVariable if
variable.semanticEquals(elementVar) =>
+ case IsNotNull(variable: NamedLambdaVariable) if variable.exprId ==
elementVar.exprId =>
+ Some(Seq.empty)
+ case IsNull(variable: NamedLambdaVariable) if variable.exprId ==
elementVar.exprId =>
+ Some(Seq.empty)
+ case variable: NamedLambdaVariable if variable.exprId ==
elementVar.exprId =>
None
case _ =>
expr.children.foldLeft(Option(Seq.empty[StructField])) {
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala
index e36224e7d5c1..69dfdbfc9a08 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala
@@ -163,6 +163,26 @@ object SelectedField {
val opt = dataTypeOpt.map(dt => MapType(keyType, dt,
valueContainsNull))
selectField(left, opt)
}
+ case ArrayFilter(argument, _: LambdaFunction) =>
+ selectField(argument, dataTypeOpt)
+ case ArraySort(argument, _: LambdaFunction, _) =>
+ selectField(argument, dataTypeOpt)
+ case Reverse(child) =>
+ selectField(child, dataTypeOpt)
+ case Shuffle(child, _) =>
+ selectField(child, dataTypeOpt)
+ case Slice(x, start, length) if start.foldable && length.foldable =>
+ selectField(x, dataTypeOpt)
+ case KnownNotContainsNull(child) =>
+ val ArrayType(_, containsNull) = child.dataType
+ val opt = dataTypeOpt.map {
+ case ArrayType(dataType, _) => ArrayType(dataType, containsNull)
+ case x =>
+ // This should not happen.
+ throw QueryCompilationErrors.dataTypeUnsupportedByClassError(
+ x, "KnownNotContainsNull")
+ }
+ selectField(child, opt)
case _ =>
None
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
index af64da7e3820..ae0257522727 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
@@ -17,8 +17,12 @@
package org.apache.spark.sql.catalyst.expressions
+import java.util.concurrent.atomic.AtomicReference
+
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.SQLHelper
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.internal.SQLConf.CASE_SENSITIVE
import org.apache.spark.sql.types._
@@ -209,6 +213,150 @@ class SchemaPruningSuite extends SparkFunSuite with
SQLHelper {
}
}
+ test("merge returned and lambda fields for array higher-order functions") {
+ val elementType = StructType.fromDDL("a int, b int, c int")
+ val eventType = StructType(Seq(
+ StructField("rules", ArrayType(elementType, containsNull = true))))
+ val event = AttributeReference("event", eventType)()
+ val argument = GetStructField(event, 0, Some("rules"))
+ val element = NamedLambdaVariable("x", elementType, nullable = true)
+ val predicate = LambdaFunction(
+ GreaterThan(GetStructField(element, 2, Some("c")), Literal(0)),
+ Seq(element))
+ val left = NamedLambdaVariable("left", elementType, nullable = true)
+ val right = NamedLambdaVariable("right", elementType, nullable = true)
+ val comparator = LambdaFunction(
+ Coalesce(Seq(
+ Subtract(
+ GetStructField(left, 2, Some("c")),
+ GetStructField(right, 2, Some("c"))),
+ Literal(0))),
+ Seq(left, right))
+
+ Seq(ArrayFilter(argument, predicate), ArraySort(argument,
comparator)).foreach { function =>
+ val selected = GetArrayStructFields(
+ function,
+ elementType(0),
+ ordinal = 0,
+ numFields = elementType.length,
+ containsNull = true)
+ val rootFields = SchemaPruning.getRootFields(selected)
+ val prunedSchema = SchemaPruning.pruneSchema(
+ StructType(Seq(StructField("event", eventType))),
+ rootFields)
+
+ assert(prunedSchema === StructType.fromDDL(
+ "event struct<rules:array<struct<a:int,c:int>>>"))
+ }
+ }
+
+ test("match separately instantiated array lambda variables by exprId") {
+ val elementType = StructType.fromDDL("a int, b int, c int")
+ val sourceSchema = StructType(Seq(
+ StructField("rules", ArrayType(elementType, containsNull = true))))
+ val rules = AttributeReference("rules", ArrayType(elementType,
containsNull = true))()
+
+ def selectFirstField(function: Expression): GetArrayStructFields = {
+ GetArrayStructFields(
+ function,
+ elementType(0),
+ ordinal = 0,
+ numFields = elementType.length,
+ containsNull = true)
+ }
+
+ def evaluateSelectedFields(selected: GetArrayStructFields): Seq[Int] = {
+ val rootFields = SchemaPruning.identifyRootFields(Seq(Alias(selected,
"out")()), Seq.empty)
+ val prunedSchema = SchemaPruning.pruneSchema(sourceSchema, rootFields)
+ assert(prunedSchema === StructType.fromDDL("rules
array<struct<a:int,c:int>>"))
+
+ val projected = ProjectionOverSchema(prunedSchema,
AttributeSet(Seq(rules)))
+ .unapply(selected).get
+ val bound = BindReferences.bindReference(projected, Seq(rules))
+ val array = new GenericArrayData(Array[Any](InternalRow(1, 3),
InternalRow(2, -1)))
+ val result = bound.eval(InternalRow(array)).asInstanceOf[ArrayData]
+ (0 until result.numElements()).map(result.getInt)
+ }
+
+ val filterArgument = NamedLambdaVariable("x", elementType, nullable = true)
+ val filterReference = filterArgument.copy(value = new
AtomicReference[Any]())
+ val filter = ArrayFilter(
+ rules,
+ LambdaFunction(
+ GreaterThan(GetStructField(filterReference, 2, Some("c")), Literal(0)),
+ Seq(filterArgument)))
+ assert(evaluateSelectedFields(selectFirstField(filter)) === Seq(1))
+
+ val leftArgument = NamedLambdaVariable("left", elementType, nullable =
true)
+ val rightArgument = NamedLambdaVariable("right", elementType, nullable =
true)
+ val leftReference = leftArgument.copy(value = new AtomicReference[Any]())
+ val rightReference = rightArgument.copy(value = new AtomicReference[Any]())
+ val sort = ArraySort(
+ rules,
+ LambdaFunction(
+ Coalesce(Seq(
+ Subtract(
+ GetStructField(leftReference, 2, Some("c")),
+ GetStructField(rightReference, 2, Some("c"))),
+ Literal(0))),
+ Seq(leftArgument, rightArgument)),
+ allowNullComparisonResult = false)
+ assert(evaluateSelectedFields(selectFirstField(sort)) === Seq(2, 1))
+ }
+
+ test("retain full nested array elements for array-returning higher-order
functions") {
+ val structType = StructType.fromDDL("a int, c int")
+ val arrayType = ArrayType(structType, containsNull = true)
+ val nestedArrayType = ArrayType(arrayType, containsNull = true)
+ val sourceSchema = StructType(Seq(StructField("rules", nestedArrayType)))
+ val rules = AttributeReference("rules", nestedArrayType)()
+
+ def selectFirstField(function: Expression): GetArrayStructFields = {
+ GetArrayStructFields(
+ GetArrayItem(function, Literal(0)),
+ structType(0),
+ ordinal = 0,
+ numFields = structType.length,
+ containsNull = true)
+ }
+
+ val element = NamedLambdaVariable("x", arrayType, nullable = true)
+ val elementC = GetArrayStructFields(
+ element,
+ structType(1),
+ ordinal = 1,
+ numFields = structType.length,
+ containsNull = true)
+ val filter = ArrayFilter(
+ rules,
+ LambdaFunction(GreaterThan(GetArrayItem(elementC, Literal(0)),
Literal(0)), Seq(element)))
+
+ val left = NamedLambdaVariable("left", arrayType, nullable = true)
+ val right = NamedLambdaVariable("right", arrayType, nullable = true)
+ def firstC(variable: NamedLambdaVariable): Expression = {
+ GetArrayItem(
+ GetArrayStructFields(
+ variable,
+ structType(1),
+ ordinal = 1,
+ numFields = structType.length,
+ containsNull = true),
+ Literal(0))
+ }
+ val sort = ArraySort(
+ rules,
+ LambdaFunction(
+ Coalesce(Seq(Subtract(firstC(left), firstC(right)), Literal(0))),
+ Seq(left, right)),
+ allowNullComparisonResult = false)
+
+ Seq(filter, sort).foreach { function =>
+ val selected = Alias(selectFirstField(function), "out")()
+ val rootFields = SchemaPruning.identifyRootFields(Seq(selected),
Seq.empty)
+ assert(SchemaPruning.pruneSchema(sourceSchema, rootFields) ===
sourceSchema)
+ }
+ }
+
test("do not collect ArrayExists and ArrayForAll lambda fields when the
whole element is used") {
val elementType = StructType.fromDDL("a int, b int")
val eventType = StructType(Seq(
@@ -225,4 +373,75 @@ class SchemaPruningSuite extends SparkFunSuite with
SQLHelper {
derivedFromAtt = false)))
}
}
+
+ test("do not prune ArrayFilter when the whole result is used") {
+ val elementType = StructType.fromDDL("a int, b int")
+ val eventType = StructType(Seq(
+ StructField("rules", ArrayType(elementType, containsNull = true))))
+ val event = AttributeReference("event", eventType)()
+ val element = NamedLambdaVariable("x", elementType, nullable = true)
+ val filtered = ArrayFilter(
+ GetStructField(event, 0, Some("rules")),
+ LambdaFunction(
+ GreaterThan(GetStructField(element, 0, Some("a")), Literal(0)),
+ Seq(element)))
+
+ val rootFields = SchemaPruning.getRootFields(filtered)
+
+ assert(rootFields.contains(
+ SchemaPruning.RootField(
+ StructField("event", eventType, nullable = true),
+ derivedFromAtt = false)))
+ }
+
+ test("do not prune strict ArraySort when the comparator can return null") {
+ val elementType = StructType.fromDDL("a int, b int")
+ val eventType = StructType(Seq(
+ StructField("rules", ArrayType(elementType, containsNull = true))))
+ val event = AttributeReference("event", eventType)()
+ val argument = GetStructField(event, 0, Some("rules"))
+ val left = NamedLambdaVariable("left", elementType, nullable = true)
+ val right = NamedLambdaVariable("right", elementType, nullable = true)
+ val comparator = LambdaFunction(Literal.create(null, IntegerType),
Seq(left, right))
+ val sorted = ArraySort(argument, comparator, allowNullComparisonResult =
false)
+ val selected = GetArrayStructFields(
+ sorted,
+ elementType(0),
+ ordinal = 0,
+ numFields = elementType.length,
+ containsNull = true)
+
+ val rootFields = SchemaPruning.getRootFields(selected)
+ val prunedSchema = SchemaPruning.pruneSchema(
+ StructType(Seq(StructField("event", eventType))),
+ rootFields)
+
+ assert(prunedSchema === StructType(Seq(StructField("event", eventType))))
+ }
+
+ test("retain input array nullability when pruning through
KnownNotContainsNull") {
+ val elementType = StructType.fromDDL("a int, b int")
+ val eventType = StructType(Seq(
+ StructField("rules", ArrayType(elementType, containsNull = true))))
+ val event = AttributeReference("event", eventType)()
+ val element = NamedLambdaVariable("x", elementType, nullable = true)
+ val compacted = KnownNotContainsNull(ArrayFilter(
+ GetStructField(event, 0, Some("rules")),
+ LambdaFunction(IsNotNull(element), Seq(element))))
+ val selected = GetArrayStructFields(
+ compacted,
+ elementType(0),
+ ordinal = 0,
+ numFields = elementType.length,
+ containsNull = false)
+
+ val rootFields = SchemaPruning.getRootFields(selected)
+ val prunedSchema = SchemaPruning.pruneSchema(
+ StructType(Seq(StructField("event", eventType))),
+ rootFields)
+ val prunedEventType =
prunedSchema("event").dataType.asInstanceOf[StructType]
+
+ assert(prunedEventType("rules").dataType ===
+ ArrayType(StructType.fromDDL("a int"), containsNull = true))
+ }
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
index 2aebf08286e1..6b8f3495f4a0 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
@@ -518,6 +518,149 @@ abstract class SchemaPruningSuite
checkAnswer(query, Row(true) :: Row(true) :: Nil)
}
+ testSchemaPruning("select nested field returned by ArrayFilter") {
+ val query = spark.table("contacts")
+ .where("p = 1")
+ .select(org.apache.spark.sql.functions.filter(
+ col("friends"), friend => friend.getField("last") ===
"Smith").getField("first"))
+
+ checkScan(query, "struct<friends:array<struct<first:string,last:string>>>")
+ checkAnswer(query,
+ Row(Array("Susan")) ::
+ Row(Array.empty[String]) ::
+ Nil)
+ }
+
+ testSchemaPruning("do not prune ArrayFilter when the whole result is used") {
+ val query = spark.table("contacts")
+ .where("p = 1")
+ .select(org.apache.spark.sql.functions.filter(
+ col("friends"), friend => friend.getField("last") === "Smith"))
+
+ checkScan(query,
"struct<friends:array<struct<first:string,middle:string,last:string>>>")
+ checkAnswer(query,
+ Row(Array(Row("Susan", "Z.", "Smith"))) ::
+ Row(Array.empty[Row]) ::
+ Nil)
+ }
+
+ testSchemaPruning("select nested field returned by array functions under
null checks") {
+ val expressions = Seq(
+ org.apache.spark.sql.functions.filter(
+ col("friends"), friend => friend.getField("last") === "Smith")
+ .getField("first").isNotNull,
+ array_sort(col("friends"), (left, right) =>
+ when(left.getField("last") < right.getField("last"), -1)
+ .when(left.getField("last") > right.getField("last"), 1)
+ .otherwise(0)).getField("first").isNull)
+
+ expressions.foreach { expression =>
+ val query = spark.table("contacts")
+ .where("p = 1")
+ .select(expression)
+
+ checkScan(query,
"struct<friends:array<struct<first:string,last:string>>>")
+ }
+ }
+
+ testSchemaPruning("select nested field returned by ArraySort") {
+ val query = spark.table("contacts")
+ .where("p = 1")
+ .select(array_sort(col("friends"), (left, right) =>
+ when(left.getField("last") < right.getField("last"), -1)
+ .when(left.getField("last") > right.getField("last"), 1)
+ .otherwise(0)).getField("first"))
+
+ checkScan(query, "struct<friends:array<struct<first:string,last:string>>>")
+ checkAnswer(query,
+ Row(Array("Susan")) ::
+ Row(Array.empty[String]) ::
+ Nil)
+ }
+
+ testSchemaPruning("do not prune default ArraySort when selecting a nested
field") {
+ val query = spark.table("contacts")
+ .where("p = 1")
+ .select(array_sort(col("friends")).getField("first"))
+
+ checkScan(query,
"struct<friends:array<struct<first:string,middle:string,last:string>>>")
+ checkAnswer(query,
+ Row(Array("Susan")) ::
+ Row(Array.empty[String]) ::
+ Nil)
+ }
+
+ testSchemaPruning("select nested field returned by Array wrappers") {
+ val queries = Seq(
+ reverse(col("friends")).getField("first"),
+ shuffle(col("friends")).getField("first"),
+ slice(col("friends"), 1, 1).getField("first"))
+
+ queries.foreach { expression =>
+ val query = spark.table("contacts")
+ .where("p = 1")
+ .select(expression)
+
+ checkScan(query, "struct<friends:array<struct<first:string>>>")
+ checkAnswer(query,
+ Row(Array("Susan")) ::
+ Row(Array.empty[String]) ::
+ Nil)
+ }
+ }
+
+ testSchemaPruning("do not prune through Slice with non-foldable bounds") {
+ val query = spark.table("contacts")
+ .where("p = 1")
+ .select(
+ slice(col("friends"), col("employer.id") + 1,
lit(1)).getField("first"),
+ col("employer.company.name"))
+
+ checkScan(query,
+ "struct<friends:array<struct<first:string,middle:string,last:string>>," +
+ "employer:struct<id:int,company:struct<name:string>>>")
+ checkAnswer(query,
+ Row(Array("Susan"), "abc") ::
+ Row(Array.empty[String], null) ::
+ Nil)
+ }
+
+ testSchemaPruning("select ArrayTransform over array-returning higher-order
functions") {
+ val expressions = Seq(
+ transform(
+ org.apache.spark.sql.functions.filter(
+ col("friends"), friend => friend.getField("last") === "Smith"),
+ friend => friend.getField("first")),
+ transform(
+ array_sort(col("friends"), (left, right) =>
+ when(left.getField("last") < right.getField("last"), -1)
+ .when(left.getField("last") > right.getField("last"), 1)
+ .otherwise(0)),
+ friend => friend.getField("first")))
+
+ expressions.foreach { expression =>
+ val query = spark.table("contacts")
+ .where("p = 1")
+ .select(expression)
+
+ checkScan(query,
"struct<friends:array<struct<first:string,last:string>>>")
+ checkAnswer(query,
+ Row(Array("Susan")) ::
+ Row(Array.empty[String]) ::
+ Nil)
+ }
+ }
+
+ testSchemaPruning("select nested field returned by ArrayCompact with null
elements") {
+ withDataSourceTable(organizations, "organizations") {
+ val query = spark.table("organizations")
+ .select(array_compact(col("team.members")).getField("id"))
+
+ checkScan(query, "struct<team:struct<members:array<struct<id:int>>>>")
+ checkAnswer(query, Row(Array(1, 0)) :: Nil)
+ }
+ }
+
testSchemaPruning("SPARK-34638: nested column prune on generator output") {
val query1 = spark.table("contacts")
.select(explode(col("friends")).as("friend"))
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]