This is an automated email from the ASF dual-hosted git repository.

sunchao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 042ad7d0c4ac [SPARK-57176][SQL] Extend nested column pruning through array-returning functions
042ad7d0c4ac is described below

commit 042ad7d0c4ac1c4d3e9fdeb48e2695fdeb861135
Author: Chao Sun <[email protected]>
AuthorDate: Fri Jun 5 09:57:03 2026 -0700

    [SPARK-57176][SQL] Extend nested column pruning through array-returning functions
    
    ### Why are the changes needed?
    
    [SPARK-57176](https://issues.apache.org/jira/browse/SPARK-57176) follows [SPARK-57022](https://issues.apache.org/jira/browse/SPARK-57022), which added nested column pruning for `transform` over `array<struct>` inputs.
    
    Array-returning functions still read the complete input element struct, even when downstream expressions and lambdas require only a subset of its nested fields. For example:
    
    ```sql
    SELECT filter(friends, friend -> friend.last = 'Smith').first
    FROM contacts
    ```
    
    If each `friends` element has the fields `first`, `middle`, and `last`, Spark currently reads all three, even though the query needs only `first` and `last`.
    
    ### What changes were proposed in this PR?
    
    - Merge downstream result-field requirements with lambda requirements for 
`filter` and comparator-based `array_sort`.
    - Propagate projected element schemas through `reverse`, `shuffle`, `slice` (when its start and length are constant), and `array_compact`.
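    - Rewrite bound lambda variable types and nested field ordinals after pruning. Lambda variables are matched by `exprId`, because bound references may be instantiated separately from the lambda's declared arguments.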
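    - Retain the complete element schema in these cases:
      - the whole result is used;
      - a lambda consumes the whole element;
      - default `array_sort` natural ordering requires the full struct;
      - a strict `array_sort` comparator can return null, because the resulting error reports the compared elements.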
    
    Functions that compare whole elements for equality or natural ordering remain out of scope, because dropping nested fields could change their results.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Eligible queries using array-returning functions over arrays of 
structs can read a narrower input schema. Query results and SQL APIs are 
unchanged.
    
    ### How was this patch tested?
    
    - `JAVA_HOME=/opt/homebrew/opt/openjdk17/libexec/openjdk.jdk/Contents/Home 
PATH=/opt/homebrew/opt/openjdk17/bin:$PATH build/sbt "catalyst/testOnly 
org.apache.spark.sql.catalyst.expressions.SchemaPruningSuite" "sql/testOnly 
org.apache.spark.sql.execution.datasources.parquet.ParquetV1SchemaPruningSuite 
org.apache.spark.sql.execution.datasources.parquet.ParquetV2SchemaPruningSuite 
org.apache.spark.sql.execution.datasources.orc.OrcV1SchemaPruningSuite 
org.apache.spark.sql.execution.dataso [...]
    - `JAVA_HOME=/opt/homebrew/opt/openjdk17/libexec/openjdk.jdk/Contents/Home 
PATH=/opt/homebrew/opt/openjdk17/bin:$PATH build/sbt catalyst/scalastyle 
sql/scalastyle`
    - `git diff --check`
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Generated-by: Codex (GPT-5)
    
    Closes #56227 from sunchao/dev/chao/codex/spark-array-returning-function-pruning.
    
    Authored-by: Chao Sun <[email protected]>
    Signed-off-by: Chao Sun <[email protected]>
---
 .../expressions/ProjectionOverSchema.scala         |  77 +++++---
 .../sql/catalyst/expressions/SchemaPruning.scala   |  75 ++++++-
 .../sql/catalyst/expressions/SelectedField.scala   |  20 ++
 .../catalyst/expressions/SchemaPruningSuite.scala  | 219 +++++++++++++++++++++
 .../execution/datasources/SchemaPruningSuite.scala | 143 ++++++++++++++
 5 files changed, 502 insertions(+), 32 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
index 27e014ecef62..503edd773933 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
@@ -69,17 +69,38 @@ case class ProjectionOverSchema(schema: StructType, output: 
AttributeSet) {
       case GetMapValue(child, key) =>
         getProjection(child).map { projection => GetMapValue(projection, key) }
       case transform @ ArrayTransform(argument, lambda: LambdaFunction) =>
-        projectArrayHigherOrderFunction(argument, lambda) { (projection, 
projectedLambda) =>
-          transform.copy(argument = projection, function = projectedLambda)
+        projectArrayHigherOrderFunction(argument, lambda, numElementVariables 
= 1) {
+          (projection, projectedLambda) =>
+            transform.copy(argument = projection, function = projectedLambda)
         }
       case exists @ ArrayExists(argument, lambda: LambdaFunction, _) =>
-        projectArrayHigherOrderFunction(argument, lambda) { (projection, 
projectedLambda) =>
-          exists.copy(argument = projection, function = projectedLambda)
+        projectArrayHigherOrderFunction(argument, lambda, numElementVariables 
= 1) {
+          (projection, projectedLambda) =>
+            exists.copy(argument = projection, function = projectedLambda)
         }
       case forall @ ArrayForAll(argument, lambda: LambdaFunction) =>
-        projectArrayHigherOrderFunction(argument, lambda) { (projection, 
projectedLambda) =>
-          forall.copy(argument = projection, function = projectedLambda)
+        projectArrayHigherOrderFunction(argument, lambda, numElementVariables 
= 1) {
+          (projection, projectedLambda) =>
+            forall.copy(argument = projection, function = projectedLambda)
         }
+      case filter @ ArrayFilter(argument, lambda: LambdaFunction) =>
+        projectArrayHigherOrderFunction(argument, lambda, numElementVariables 
= 1) {
+          (projection, projectedLambda) =>
+            filter.copy(argument = projection, function = projectedLambda)
+        }
+      case sort @ ArraySort(argument, lambda: LambdaFunction, _) =>
+        projectArrayHigherOrderFunction(argument, lambda, numElementVariables 
= 2) {
+          (projection, projectedLambda) =>
+            sort.copy(argument = projection, function = projectedLambda)
+        }
+      case reverse @ Reverse(child) =>
+        getProjection(child).map(projection => reverse.copy(child = 
projection))
+      case shuffle @ Shuffle(child, _) =>
+        getProjection(child).map(projection => shuffle.copy(child = 
projection))
+      case slice @ Slice(x, start, length) if start.foldable && 
length.foldable =>
+        getProjection(x).map(projection => slice.copy(x = projection))
+      case knownNotContainsNull @ KnownNotContainsNull(child) =>
+        getProjection(child).map(projection => knownNotContainsNull.copy(child 
= projection))
       case GetStructFieldObject(child, field: StructField) =>
         getProjection(child).map(p => (p, p.dataType)).map {
           case (projection, projSchema: StructType) =>
@@ -97,29 +118,32 @@ case class ProjectionOverSchema(schema: StructType, 
output: AttributeSet) {
 
   private def projectArrayHigherOrderFunction(
       argument: Expression,
-      lambda: LambdaFunction)(
+      lambda: LambdaFunction,
+      numElementVariables: Int)(
       rebuild: (Expression, LambdaFunction) => Expression): Option[Expression] 
= {
     getProjection(argument).map {
       case projection @ ArrayTypeProjection(projectedElementSchema) =>
-        lambda.arguments.headOption match {
-          case Some(elementVar: NamedLambdaVariable) =>
-            // Pruning fields changes the physical ordinal layout of the 
element struct.
-            // For example, pruning struct<a, b, c> to struct<a, c> moves c 
from ordinal 2
-            // to ordinal 1, so rewrite both the variable type and its field 
accesses.
-            val projectedElementVar = elementVar.copy(dataType = 
projectedElementSchema)
-            val lambdaProjection =
-              ProjectionOverLambdaVariable(elementVar, projectedElementVar)
-            val projectedBody = lambda.function.transformDown {
-              case lambdaProjection(expr) => expr
-            }
-            rebuild(
-              projection,
-              lambda.copy(
-                function = projectedBody,
-                arguments = projectedElementVar +: lambda.arguments.tail))
-          case _ =>
-            rebuild(projection, lambda)
+        val projectedArguments = lambda.arguments.zipWithIndex.map {
+          case (elementVar: NamedLambdaVariable, index) if index < 
numElementVariables =>
+            elementVar.copy(dataType = projectedElementSchema)
+          case (argument, _) =>
+            argument
         }
+        val projectedBody =
+          lambda.arguments.zip(projectedArguments).foldLeft(lambda.function) {
+            case (body, (elementVar: NamedLambdaVariable, projectedElementVar:
+                NamedLambdaVariable)) if elementVar ne projectedElementVar =>
+              val lambdaProjection =
+                ProjectionOverLambdaVariable(elementVar, projectedElementVar)
+              body.transformDown {
+                case lambdaProjection(expr) => expr
+              }
+            case (body, _) =>
+              body
+          }
+        rebuild(
+          projection,
+          lambda.copy(function = projectedBody, arguments = 
projectedArguments))
       case projection =>
         rebuild(projection, lambda)
     }
@@ -135,6 +159,7 @@ case class ProjectionOverSchema(schema: StructType, output: 
AttributeSet) {
   /**
    * Rewrites references rooted at one bound lambda element to use its 
projected type and
    * recomputes nested field ordinals against each projected struct in the 
access path.
+   * Bound lambda references are matched by exprId because they may be 
instantiated separately.
    * This must support the same access paths collected by `SchemaPruning` for 
lambda variables;
    * currently both sides support only `GetStructField` chains.
    */
@@ -144,7 +169,7 @@ case class ProjectionOverSchema(schema: StructType, output: 
AttributeSet) {
     def unapply(expr: Expression): Option[Expression] = project(expr)
 
     private def project(expr: Expression): Option[Expression] = expr match {
-      case variable: NamedLambdaVariable if variable.semanticEquals(original) 
=>
+      case variable: NamedLambdaVariable if variable.exprId == original.exprId 
=>
         Some(projected)
       case GetStructFieldObject(child, field: StructField) =>
         project(child).map { projection =>
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
index e8aa722bbe23..a0f1bec46410 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
@@ -149,14 +149,18 @@ object SchemaPruning extends SQLConfHelper {
       case att: Attribute =>
         RootField(StructField(att.name, att.dataType, att.nullable, 
att.metadata),
           derivedFromAtt = true) :: Nil
-      case SelectedField(field) => RootField(field, derivedFromAtt = false) :: 
Nil
+      case SelectedField(field) =>
+        RootField(field, derivedFromAtt = false) +:
+          getArrayReturningHigherOrderFunctionRootFields(expr)
       // Root field accesses by `IsNotNull` and `IsNull` are special cases as 
the expressions
       // don't actually use any nested fields. These root field accesses might 
be excluded later
       // if there are any nested fields accesses in the query plan.
       case IsNotNull(SelectedField(field)) =>
-        RootField(field, derivedFromAtt = false, prunedIfAnyChildAccessed = 
true) :: Nil
+        RootField(field, derivedFromAtt = false, prunedIfAnyChildAccessed = 
true) +:
+          getArrayReturningHigherOrderFunctionRootFields(expr)
       case IsNull(SelectedField(field)) =>
-        RootField(field, derivedFromAtt = false, prunedIfAnyChildAccessed = 
true) :: Nil
+        RootField(field, derivedFromAtt = false, prunedIfAnyChildAccessed = 
true) +:
+          getArrayReturningHigherOrderFunctionRootFields(expr)
       case IsNotNull(_: Attribute) | IsNull(_: Attribute) =>
         
expr.children.flatMap(getRootFields).map(_.copy(prunedIfAnyChildAccessed = 
true))
       case s: SubqueryExpression =>
@@ -180,7 +184,8 @@ object SchemaPruning extends SQLConfHelper {
         getArrayHigherOrderFunctionRootField(argument, lambda.function, 
elementVar)
     }.flatten.toSeq.map(field => RootField(field, derivedFromAtt = false))
     if (nestedRootFields.nonEmpty) {
-      nestedRootFields ++ getRootFields(lambda.function)
+      nestedRootFields ++ getRootFields(lambda.function) ++
+        getArrayReturningHigherOrderFunctionRootFields(argument)
     } else {
       expr.children.flatMap(getRootFields)
     }
@@ -208,6 +213,58 @@ object SchemaPruning extends SQLConfHelper {
     }
   }
 
+  private def getArrayReturningHigherOrderFunctionRootFields(expr: 
Expression): Seq[RootField] = {
+    expr match {
+      case ArrayFilter(argument, lambda: LambdaFunction) =>
+        getArrayReturningHigherOrderFunctionRootFields(argument, lambda, 
numElementVariables = 1)
+      case ArraySort(argument, lambda: LambdaFunction, 
allowNullComparisonResult)
+          if !allowNullComparisonResult && lambda.function.nullable =>
+        // The strict null-comparator error includes the compared values, so 
pruning their
+        // element fields would change the observable error parameters.
+        getRootFields(argument) ++ getRootFields(lambda.function)
+      case ArraySort(argument, lambda: LambdaFunction, _) =>
+        getArrayReturningHigherOrderFunctionRootFields(argument, lambda, 
numElementVariables = 2)
+      case _ =>
+        expr.children.flatMap(getArrayReturningHigherOrderFunctionRootFields)
+    }
+  }
+
+  private def getArrayReturningHigherOrderFunctionRootFields(
+      argument: Expression,
+      lambda: LambdaFunction,
+      numElementVariables: Int): Seq[RootField] = {
+    val nestedRootFields = argument.dataType match {
+      case ArrayType(_: StructType, containsNull) =>
+        val elementVariables = 
lambda.arguments.take(numElementVariables).collect {
+          case elementVar: NamedLambdaVariable => elementVar
+        }
+        val selectedFields = 
elementVariables.map(collectLambdaVariableFields(lambda.function, _))
+        if (elementVariables.length == numElementVariables && 
selectedFields.forall(_.isDefined)) {
+          val fields = selectedFields.flatten.flatten
+          if (fields.nonEmpty) {
+            val mergedElementSchema = fields
+              .map(field => StructType(Array(field)))
+              .reduceLeft(_ merge _)
+            SelectedField.withDataType(
+              argument,
+              ArrayType(mergedElementSchema, containsNull)).toSeq match {
+              case Seq() => getRootFields(argument).map(_.field)
+              case fields => fields
+            }
+          } else {
+            Seq.empty
+          }
+        } else {
+          getRootFields(argument).map(_.field)
+        }
+      case _ =>
+        getRootFields(argument).map(_.field)
+    }
+    nestedRootFields.map(field => RootField(field, derivedFromAtt = false)) ++
+      getRootFields(lambda.function) ++
+      getArrayReturningHigherOrderFunctionRootFields(argument)
+  }
+
   /**
    * Collects statically identifiable nested fields read from `elementVar`.
    *
@@ -216,6 +273,8 @@ object SchemaPruning extends SQLConfHelper {
    * means the full element is required somewhere (for example, `x => 
struct(x.a, x)`), so it is
    * not safe to prune the element struct.
    *
+   * Bound lambda references are matched by exprId because they may be 
instantiated separately.
+   *
    * Currently only `GetStructField` chains rooted at `elementVar` are 
collected; array or map
    * traversal within the lambda conservatively requires the full element. 
Keep this set of
    * supported paths in sync with `ProjectionOverLambdaVariable` in 
`ProjectionOverSchema`.
@@ -224,9 +283,13 @@ object SchemaPruning extends SQLConfHelper {
       expr: Expression,
       elementVar: NamedLambdaVariable): Option[Seq[StructField]] = {
     expr match {
-      case LambdaVariableField(field, variable) if 
variable.semanticEquals(elementVar) =>
+      case LambdaVariableField(field, variable) if variable.exprId == 
elementVar.exprId =>
         Some(field :: Nil)
-      case variable: NamedLambdaVariable if 
variable.semanticEquals(elementVar) =>
+      case IsNotNull(variable: NamedLambdaVariable) if variable.exprId == 
elementVar.exprId =>
+        Some(Seq.empty)
+      case IsNull(variable: NamedLambdaVariable) if variable.exprId == 
elementVar.exprId =>
+        Some(Seq.empty)
+      case variable: NamedLambdaVariable if variable.exprId == 
elementVar.exprId =>
         None
       case _ =>
         expr.children.foldLeft(Option(Seq.empty[StructField])) {
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala
index e36224e7d5c1..69dfdbfc9a08 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala
@@ -163,6 +163,26 @@ object SelectedField {
             val opt = dataTypeOpt.map(dt => MapType(keyType, dt, 
valueContainsNull))
             selectField(left, opt)
         }
+      case ArrayFilter(argument, _: LambdaFunction) =>
+        selectField(argument, dataTypeOpt)
+      case ArraySort(argument, _: LambdaFunction, _) =>
+        selectField(argument, dataTypeOpt)
+      case Reverse(child) =>
+        selectField(child, dataTypeOpt)
+      case Shuffle(child, _) =>
+        selectField(child, dataTypeOpt)
+      case Slice(x, start, length) if start.foldable && length.foldable =>
+        selectField(x, dataTypeOpt)
+      case KnownNotContainsNull(child) =>
+        val ArrayType(_, containsNull) = child.dataType
+        val opt = dataTypeOpt.map {
+          case ArrayType(dataType, _) => ArrayType(dataType, containsNull)
+          case x =>
+            // This should not happen.
+            throw QueryCompilationErrors.dataTypeUnsupportedByClassError(
+              x, "KnownNotContainsNull")
+        }
+        selectField(child, opt)
       case _ =>
         None
     }
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
index af64da7e3820..ae0257522727 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
@@ -17,8 +17,12 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import java.util.concurrent.atomic.AtomicReference
+
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.plans.SQLHelper
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
 import org.apache.spark.sql.internal.SQLConf.CASE_SENSITIVE
 import org.apache.spark.sql.types._
 
@@ -209,6 +213,150 @@ class SchemaPruningSuite extends SparkFunSuite with 
SQLHelper {
     }
   }
 
+  test("merge returned and lambda fields for array higher-order functions") {
+    val elementType = StructType.fromDDL("a int, b int, c int")
+    val eventType = StructType(Seq(
+      StructField("rules", ArrayType(elementType, containsNull = true))))
+    val event = AttributeReference("event", eventType)()
+    val argument = GetStructField(event, 0, Some("rules"))
+    val element = NamedLambdaVariable("x", elementType, nullable = true)
+    val predicate = LambdaFunction(
+      GreaterThan(GetStructField(element, 2, Some("c")), Literal(0)),
+      Seq(element))
+    val left = NamedLambdaVariable("left", elementType, nullable = true)
+    val right = NamedLambdaVariable("right", elementType, nullable = true)
+    val comparator = LambdaFunction(
+      Coalesce(Seq(
+        Subtract(
+          GetStructField(left, 2, Some("c")),
+          GetStructField(right, 2, Some("c"))),
+        Literal(0))),
+      Seq(left, right))
+
+    Seq(ArrayFilter(argument, predicate), ArraySort(argument, 
comparator)).foreach { function =>
+      val selected = GetArrayStructFields(
+        function,
+        elementType(0),
+        ordinal = 0,
+        numFields = elementType.length,
+        containsNull = true)
+      val rootFields = SchemaPruning.getRootFields(selected)
+      val prunedSchema = SchemaPruning.pruneSchema(
+        StructType(Seq(StructField("event", eventType))),
+        rootFields)
+
+      assert(prunedSchema === StructType.fromDDL(
+        "event struct<rules:array<struct<a:int,c:int>>>"))
+    }
+  }
+
+  test("match separately instantiated array lambda variables by exprId") {
+    val elementType = StructType.fromDDL("a int, b int, c int")
+    val sourceSchema = StructType(Seq(
+      StructField("rules", ArrayType(elementType, containsNull = true))))
+    val rules = AttributeReference("rules", ArrayType(elementType, 
containsNull = true))()
+
+    def selectFirstField(function: Expression): GetArrayStructFields = {
+      GetArrayStructFields(
+        function,
+        elementType(0),
+        ordinal = 0,
+        numFields = elementType.length,
+        containsNull = true)
+    }
+
+    def evaluateSelectedFields(selected: GetArrayStructFields): Seq[Int] = {
+      val rootFields = SchemaPruning.identifyRootFields(Seq(Alias(selected, 
"out")()), Seq.empty)
+      val prunedSchema = SchemaPruning.pruneSchema(sourceSchema, rootFields)
+      assert(prunedSchema === StructType.fromDDL("rules 
array<struct<a:int,c:int>>"))
+
+      val projected = ProjectionOverSchema(prunedSchema, 
AttributeSet(Seq(rules)))
+        .unapply(selected).get
+      val bound = BindReferences.bindReference(projected, Seq(rules))
+      val array = new GenericArrayData(Array[Any](InternalRow(1, 3), 
InternalRow(2, -1)))
+      val result = bound.eval(InternalRow(array)).asInstanceOf[ArrayData]
+      (0 until result.numElements()).map(result.getInt)
+    }
+
+    val filterArgument = NamedLambdaVariable("x", elementType, nullable = true)
+    val filterReference = filterArgument.copy(value = new 
AtomicReference[Any]())
+    val filter = ArrayFilter(
+      rules,
+      LambdaFunction(
+        GreaterThan(GetStructField(filterReference, 2, Some("c")), Literal(0)),
+        Seq(filterArgument)))
+    assert(evaluateSelectedFields(selectFirstField(filter)) === Seq(1))
+
+    val leftArgument = NamedLambdaVariable("left", elementType, nullable = 
true)
+    val rightArgument = NamedLambdaVariable("right", elementType, nullable = 
true)
+    val leftReference = leftArgument.copy(value = new AtomicReference[Any]())
+    val rightReference = rightArgument.copy(value = new AtomicReference[Any]())
+    val sort = ArraySort(
+      rules,
+      LambdaFunction(
+        Coalesce(Seq(
+          Subtract(
+            GetStructField(leftReference, 2, Some("c")),
+            GetStructField(rightReference, 2, Some("c"))),
+          Literal(0))),
+        Seq(leftArgument, rightArgument)),
+      allowNullComparisonResult = false)
+    assert(evaluateSelectedFields(selectFirstField(sort)) === Seq(2, 1))
+  }
+
+  test("retain full nested array elements for array-returning higher-order 
functions") {
+    val structType = StructType.fromDDL("a int, c int")
+    val arrayType = ArrayType(structType, containsNull = true)
+    val nestedArrayType = ArrayType(arrayType, containsNull = true)
+    val sourceSchema = StructType(Seq(StructField("rules", nestedArrayType)))
+    val rules = AttributeReference("rules", nestedArrayType)()
+
+    def selectFirstField(function: Expression): GetArrayStructFields = {
+      GetArrayStructFields(
+        GetArrayItem(function, Literal(0)),
+        structType(0),
+        ordinal = 0,
+        numFields = structType.length,
+        containsNull = true)
+    }
+
+    val element = NamedLambdaVariable("x", arrayType, nullable = true)
+    val elementC = GetArrayStructFields(
+      element,
+      structType(1),
+      ordinal = 1,
+      numFields = structType.length,
+      containsNull = true)
+    val filter = ArrayFilter(
+      rules,
+      LambdaFunction(GreaterThan(GetArrayItem(elementC, Literal(0)), 
Literal(0)), Seq(element)))
+
+    val left = NamedLambdaVariable("left", arrayType, nullable = true)
+    val right = NamedLambdaVariable("right", arrayType, nullable = true)
+    def firstC(variable: NamedLambdaVariable): Expression = {
+      GetArrayItem(
+        GetArrayStructFields(
+          variable,
+          structType(1),
+          ordinal = 1,
+          numFields = structType.length,
+          containsNull = true),
+        Literal(0))
+    }
+    val sort = ArraySort(
+      rules,
+      LambdaFunction(
+        Coalesce(Seq(Subtract(firstC(left), firstC(right)), Literal(0))),
+        Seq(left, right)),
+      allowNullComparisonResult = false)
+
+    Seq(filter, sort).foreach { function =>
+      val selected = Alias(selectFirstField(function), "out")()
+      val rootFields = SchemaPruning.identifyRootFields(Seq(selected), 
Seq.empty)
+      assert(SchemaPruning.pruneSchema(sourceSchema, rootFields) === 
sourceSchema)
+    }
+  }
+
   test("do not collect ArrayExists and ArrayForAll lambda fields when the 
whole element is used") {
     val elementType = StructType.fromDDL("a int, b int")
     val eventType = StructType(Seq(
@@ -225,4 +373,75 @@ class SchemaPruningSuite extends SparkFunSuite with 
SQLHelper {
           derivedFromAtt = false)))
     }
   }
+
+  test("do not prune ArrayFilter when the whole result is used") {
+    val elementType = StructType.fromDDL("a int, b int")
+    val eventType = StructType(Seq(
+      StructField("rules", ArrayType(elementType, containsNull = true))))
+    val event = AttributeReference("event", eventType)()
+    val element = NamedLambdaVariable("x", elementType, nullable = true)
+    val filtered = ArrayFilter(
+      GetStructField(event, 0, Some("rules")),
+      LambdaFunction(
+        GreaterThan(GetStructField(element, 0, Some("a")), Literal(0)),
+        Seq(element)))
+
+    val rootFields = SchemaPruning.getRootFields(filtered)
+
+    assert(rootFields.contains(
+      SchemaPruning.RootField(
+        StructField("event", eventType, nullable = true),
+        derivedFromAtt = false)))
+  }
+
+  test("do not prune strict ArraySort when the comparator can return null") {
+    val elementType = StructType.fromDDL("a int, b int")
+    val eventType = StructType(Seq(
+      StructField("rules", ArrayType(elementType, containsNull = true))))
+    val event = AttributeReference("event", eventType)()
+    val argument = GetStructField(event, 0, Some("rules"))
+    val left = NamedLambdaVariable("left", elementType, nullable = true)
+    val right = NamedLambdaVariable("right", elementType, nullable = true)
+    val comparator = LambdaFunction(Literal.create(null, IntegerType), 
Seq(left, right))
+    val sorted = ArraySort(argument, comparator, allowNullComparisonResult = 
false)
+    val selected = GetArrayStructFields(
+      sorted,
+      elementType(0),
+      ordinal = 0,
+      numFields = elementType.length,
+      containsNull = true)
+
+    val rootFields = SchemaPruning.getRootFields(selected)
+    val prunedSchema = SchemaPruning.pruneSchema(
+      StructType(Seq(StructField("event", eventType))),
+      rootFields)
+
+    assert(prunedSchema === StructType(Seq(StructField("event", eventType))))
+  }
+
+  test("retain input array nullability when pruning through 
KnownNotContainsNull") {
+    val elementType = StructType.fromDDL("a int, b int")
+    val eventType = StructType(Seq(
+      StructField("rules", ArrayType(elementType, containsNull = true))))
+    val event = AttributeReference("event", eventType)()
+    val element = NamedLambdaVariable("x", elementType, nullable = true)
+    val compacted = KnownNotContainsNull(ArrayFilter(
+      GetStructField(event, 0, Some("rules")),
+      LambdaFunction(IsNotNull(element), Seq(element))))
+    val selected = GetArrayStructFields(
+      compacted,
+      elementType(0),
+      ordinal = 0,
+      numFields = elementType.length,
+      containsNull = false)
+
+    val rootFields = SchemaPruning.getRootFields(selected)
+    val prunedSchema = SchemaPruning.pruneSchema(
+      StructType(Seq(StructField("event", eventType))),
+      rootFields)
+    val prunedEventType = 
prunedSchema("event").dataType.asInstanceOf[StructType]
+
+    assert(prunedEventType("rules").dataType ===
+      ArrayType(StructType.fromDDL("a int"), containsNull = true))
+  }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
index 2aebf08286e1..6b8f3495f4a0 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
@@ -518,6 +518,149 @@ abstract class SchemaPruningSuite
     checkAnswer(query, Row(true) :: Row(true) :: Nil)
   }
 
+  testSchemaPruning("select nested field returned by ArrayFilter") {
+    val query = spark.table("contacts")
+      .where("p = 1")
+      .select(org.apache.spark.sql.functions.filter(
+        col("friends"), friend => friend.getField("last") === 
"Smith").getField("first"))
+
+    checkScan(query, "struct<friends:array<struct<first:string,last:string>>>")
+    checkAnswer(query,
+      Row(Array("Susan")) ::
+      Row(Array.empty[String]) ::
+      Nil)
+  }
+
+  testSchemaPruning("do not prune ArrayFilter when the whole result is used") {
+    val query = spark.table("contacts")
+      .where("p = 1")
+      .select(org.apache.spark.sql.functions.filter(
+        col("friends"), friend => friend.getField("last") === "Smith"))
+
+    checkScan(query, 
"struct<friends:array<struct<first:string,middle:string,last:string>>>")
+    checkAnswer(query,
+      Row(Array(Row("Susan", "Z.", "Smith"))) ::
+      Row(Array.empty[Row]) ::
+      Nil)
+  }
+
+  testSchemaPruning("select nested field returned by array functions under 
null checks") {
+    val expressions = Seq(
+      org.apache.spark.sql.functions.filter(
+        col("friends"), friend => friend.getField("last") === "Smith")
+        .getField("first").isNotNull,
+      array_sort(col("friends"), (left, right) =>
+        when(left.getField("last") < right.getField("last"), -1)
+          .when(left.getField("last") > right.getField("last"), 1)
+          .otherwise(0)).getField("first").isNull)
+
+    expressions.foreach { expression =>
+      val query = spark.table("contacts")
+        .where("p = 1")
+        .select(expression)
+
+      checkScan(query, 
"struct<friends:array<struct<first:string,last:string>>>")
+    }
+  }
+
+  testSchemaPruning("select nested field returned by ArraySort") {
+    val query = spark.table("contacts")
+      .where("p = 1")
+      .select(array_sort(col("friends"), (left, right) =>
+        when(left.getField("last") < right.getField("last"), -1)
+          .when(left.getField("last") > right.getField("last"), 1)
+          .otherwise(0)).getField("first"))
+
+    checkScan(query, "struct<friends:array<struct<first:string,last:string>>>")
+    checkAnswer(query,
+      Row(Array("Susan")) ::
+      Row(Array.empty[String]) ::
+      Nil)
+  }
+
+  testSchemaPruning("do not prune default ArraySort when selecting a nested 
field") {
+    val query = spark.table("contacts")
+      .where("p = 1")
+      .select(array_sort(col("friends")).getField("first"))
+
+    checkScan(query, 
"struct<friends:array<struct<first:string,middle:string,last:string>>>")
+    checkAnswer(query,
+      Row(Array("Susan")) ::
+      Row(Array.empty[String]) ::
+      Nil)
+  }
+
+  testSchemaPruning("select nested field returned by Array wrappers") {
+    val queries = Seq(
+      reverse(col("friends")).getField("first"),
+      shuffle(col("friends")).getField("first"),
+      slice(col("friends"), 1, 1).getField("first"))
+
+    queries.foreach { expression =>
+      val query = spark.table("contacts")
+        .where("p = 1")
+        .select(expression)
+
+      checkScan(query, "struct<friends:array<struct<first:string>>>")
+      checkAnswer(query,
+        Row(Array("Susan")) ::
+        Row(Array.empty[String]) ::
+        Nil)
+    }
+  }
+
+  testSchemaPruning("do not prune through Slice with non-foldable bounds") {
+    val query = spark.table("contacts")
+      .where("p = 1")
+      .select(
+        slice(col("friends"), col("employer.id") + 1, 
lit(1)).getField("first"),
+        col("employer.company.name"))
+
+    checkScan(query,
+      "struct<friends:array<struct<first:string,middle:string,last:string>>," +
+        "employer:struct<id:int,company:struct<name:string>>>")
+    checkAnswer(query,
+      Row(Array("Susan"), "abc") ::
+      Row(Array.empty[String], null) ::
+      Nil)
+  }
+
+  testSchemaPruning("select ArrayTransform over array-returning higher-order 
functions") {
+    val expressions = Seq(
+      transform(
+        org.apache.spark.sql.functions.filter(
+          col("friends"), friend => friend.getField("last") === "Smith"),
+        friend => friend.getField("first")),
+      transform(
+        array_sort(col("friends"), (left, right) =>
+          when(left.getField("last") < right.getField("last"), -1)
+            .when(left.getField("last") > right.getField("last"), 1)
+            .otherwise(0)),
+        friend => friend.getField("first")))
+
+    expressions.foreach { expression =>
+      val query = spark.table("contacts")
+        .where("p = 1")
+        .select(expression)
+
+      checkScan(query, 
"struct<friends:array<struct<first:string,last:string>>>")
+      checkAnswer(query,
+        Row(Array("Susan")) ::
+        Row(Array.empty[String]) ::
+        Nil)
+    }
+  }
+
+  testSchemaPruning("select nested field returned by ArrayCompact with null 
elements") {
+    withDataSourceTable(organizations, "organizations") {
+      val query = spark.table("organizations")
+        .select(array_compact(col("team.members")).getField("id"))
+
+      checkScan(query, "struct<team:struct<members:array<struct<id:int>>>>")
+      checkAnswer(query, Row(Array(1, 0)) :: Nil)
+    }
+  }
+
   testSchemaPruning("SPARK-34638: nested column prune on generator output") {
     val query1 = spark.table("contacts")
       .select(explode(col("friends")).as("friend"))


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to