This is an automated email from the ASF dual-hosted git repository.

huaxingao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a35c9f36da63 [SPARK-53805][SQL] Push Variant into DSv2 scan
a35c9f36da63 is described below

commit a35c9f36da63202d68dde70f5ad3058ba7357715
Author: Huaxin Gao <[email protected]>
AuthorDate: Fri Oct 10 12:15:59 2025 -0700

    [SPARK-53805][SQL] Push Variant into DSv2 scan
    
    ### What changes were proposed in this pull request?
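    Push Variant into DSv2 scan: `PushVariantIntoScan` now also rewrites plans over
    `DataSourceV2Relation`, and the rule now runs before `V2ScanRelationPushDown`.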
    
    ### Why are the changes needed?
    With this change, a DSv2 scan only needs to fetch the shredded columns
    required by the plan.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    New tests in `VariantV2ReadSuite`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #52522 from huaxingao/variant-v2-pushdown.
    
    Authored-by: Huaxin Gao <[email protected]>
    Signed-off-by: Huaxin Gao <[email protected]>
---
 .../spark/sql/execution/SparkOptimizer.scala       |   4 +-
 .../datasources/PushVariantIntoScan.scala          | 109 ++++++++++++---
 .../datasources/v2/VariantV2ReadSuite.scala        | 148 +++++++++++++++++++++
 3 files changed, 242 insertions(+), 19 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index 8edb59f49282..9699d8a2563f 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -40,11 +40,11 @@ class SparkOptimizer(
       SchemaPruning,
       GroupBasedRowLevelOperationScanPlanning,
       V1Writes,
+      PushVariantIntoScan,
       V2ScanRelationPushDown,
       V2ScanPartitioningAndOrdering,
       V2Writes,
-      PruneFileSourcePartitions,
-      PushVariantIntoScan)
+      PruneFileSourcePartitions)
 
   override def preCBORules: Seq[Rule[LogicalPlan]] =
     Seq(OptimizeMetadataOnlyDeleteFromTable)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
index 5960cf8c38ce..6ce53e3367c4 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
@@ -279,6 +280,8 @@ object PushVariantIntoScan extends Rule[LogicalPlan] {
       relation @ LogicalRelationWithTable(
       hadoopFsRelation@HadoopFsRelation(_, _, _, _, _: ParquetFileFormat, _), 
_)) =>
         rewritePlan(p, projectList, filters, relation, hadoopFsRelation)
+      case p@PhysicalOperation(projectList, filters, relation: 
DataSourceV2Relation) =>
+        rewriteV2RelationPlan(p, projectList, filters, relation)
     }
   }
 
@@ -288,23 +291,91 @@ object PushVariantIntoScan extends Rule[LogicalPlan] {
       filters: Seq[Expression],
       relation: LogicalRelation,
       hadoopFsRelation: HadoopFsRelation): LogicalPlan = {
-    val variants = new VariantInRelation
-
     val schemaAttributes = relation.resolve(hadoopFsRelation.dataSchema,
       hadoopFsRelation.sparkSession.sessionState.analyzer.resolver)
-    val defaultValues = 
ResolveDefaultColumns.existenceDefaultValues(StructType(
-      schemaAttributes.map(a => StructField(a.name, a.dataType, a.nullable, 
a.metadata))))
-    for ((a, defaultValue) <- schemaAttributes.zip(defaultValues)) {
-      variants.addVariantFields(a.exprId, a.dataType, defaultValue, Nil)
+
+    // Collect variant fields from the relation output
+    val variants = collectAndRewriteVariants(schemaAttributes)
+    if (variants.mapping.isEmpty) return originalPlan
+
+    // Collect requested fields from projections and filters
+    projectList.foreach(variants.collectRequestedFields)
+    filters.foreach(variants.collectRequestedFields)
+    // `collectRequestedFields` may have removed all variant columns.
+    if (variants.mapping.forall(_._2.isEmpty)) return originalPlan
+
+    // Build attribute map with rewritten types
+    val attributeMap = buildAttributeMap(schemaAttributes, variants)
+
+    // Build new schema with variant types replaced by struct types
+    val newFields = schemaAttributes.map { a =>
+      val dataType = attributeMap(a.exprId).dataType
+      StructField(a.name, dataType, a.nullable, a.metadata)
     }
+    // Update relation output attributes with new types
+    val newOutput = relation.output.map(a => attributeMap.getOrElse(a.exprId, 
a))
+
+    // Update HadoopFsRelation's data schema so the file source reads the 
struct columns
+    val newHadoopFsRelation = hadoopFsRelation.copy(dataSchema = 
StructType(newFields))(
+      hadoopFsRelation.sparkSession)
+    val newRelation = relation.copy(relation = newHadoopFsRelation, output = 
newOutput.toIndexedSeq)
+
+    // Build filter and project with rewritten expressions
+    buildFilterAndProject(newRelation, projectList, filters, variants, 
attributeMap)
+  }
+
+  private def rewriteV2RelationPlan(
+      originalPlan: LogicalPlan,
+      projectList: Seq[NamedExpression],
+      filters: Seq[Expression],
+      relation: DataSourceV2Relation): LogicalPlan = {
+
+    // Collect variant fields from the relation output
+    val variants = collectAndRewriteVariants(relation.output)
     if (variants.mapping.isEmpty) return originalPlan
 
+    // Collect requested fields from projections and filters
     projectList.foreach(variants.collectRequestedFields)
     filters.foreach(variants.collectRequestedFields)
     // `collectRequestedFields` may have removed all variant columns.
     if (variants.mapping.forall(_._2.isEmpty)) return originalPlan
 
-    val attributeMap = schemaAttributes.map { a =>
+    // Build attribute map with rewritten types
+    val attributeMap = buildAttributeMap(relation.output, variants)
+
+    // Update relation output attributes with new types
+    // Note: DSv2 doesn't need to update the schema in the relation itself. 
The schema will be
+    // communicated to the data source later via 
V2ScanRelationPushDown.pruneColumns() API.
+    val newOutput = relation.output.map(a => attributeMap.getOrElse(a.exprId, 
a))
+    val newRelation = relation.copy(output = newOutput.toIndexedSeq)
+
+    // Build filter and project with rewritten expressions
+    buildFilterAndProject(newRelation, projectList, filters, variants, 
attributeMap)
+  }
+
+  /**
+   * Collect variant fields and return initialized VariantInRelation.
+   */
+  private def collectAndRewriteVariants(
+      schemaAttributes: Seq[Attribute]): VariantInRelation = {
+    val variants = new VariantInRelation
+    val defaultValues = 
ResolveDefaultColumns.existenceDefaultValues(StructType(
+      schemaAttributes.map(a => StructField(a.name, a.dataType, a.nullable, 
a.metadata))))
+
+    for ((a, defaultValue) <- schemaAttributes.zip(defaultValues)) {
+      variants.addVariantFields(a.exprId, a.dataType, defaultValue, Nil)
+    }
+
+    variants
+  }
+
+  /**
+   * Build attribute map with rewritten variant types.
+   */
+  private def buildAttributeMap(
+      schemaAttributes: Seq[Attribute],
+      variants: VariantInRelation): Map[ExprId, AttributeReference] = {
+    schemaAttributes.map { a =>
       if (variants.mapping.get(a.exprId).exists(_.nonEmpty)) {
         val newType = variants.rewriteType(a.exprId, a.dataType, Nil)
         val newAttr = AttributeReference(a.name, newType, a.nullable, 
a.metadata)(
@@ -316,21 +387,24 @@ object PushVariantIntoScan extends Rule[LogicalPlan] {
         (a.exprId, a.asInstanceOf[AttributeReference])
       }
     }.toMap
-    val newFields = schemaAttributes.map { a =>
-      val dataType = attributeMap(a.exprId).dataType
-      StructField(a.name, dataType, a.nullable, a.metadata)
-    }
-    val newOutput = relation.output.map(a => attributeMap.getOrElse(a.exprId, 
a))
+  }
 
-    val newHadoopFsRelation = hadoopFsRelation.copy(dataSchema = 
StructType(newFields))(
-      hadoopFsRelation.sparkSession)
-    val newRelation = relation.copy(relation = newHadoopFsRelation, output = 
newOutput.toIndexedSeq)
+  /**
+   * Build the final Project(Filter(relation)) plan with rewritten expressions.
+   */
+  private def buildFilterAndProject(
+      relation: LogicalPlan,
+      projectList: Seq[NamedExpression],
+      filters: Seq[Expression],
+      variants: VariantInRelation,
+      attributeMap: Map[ExprId, AttributeReference]): LogicalPlan = {
 
     val withFilter = if (filters.nonEmpty) {
-      Filter(filters.map(variants.rewriteExpr(_, attributeMap)).reduce(And), 
newRelation)
+      Filter(filters.map(variants.rewriteExpr(_, attributeMap)).reduce(And), 
relation)
     } else {
-      newRelation
+      relation
     }
+
     val newProjectList = projectList.map { e =>
       val rewritten = variants.rewriteExpr(e, attributeMap)
       rewritten match {
@@ -341,6 +415,7 @@ object PushVariantIntoScan extends Rule[LogicalPlan] {
         case _ => Alias(rewritten, e.name)(e.exprId, e.qualifier)
       }
     }
+
     Project(newProjectList, withFilter)
   }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/VariantV2ReadSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/VariantV2ReadSuite.scala
new file mode 100644
index 000000000000..a6521dfe76da
--- /dev/null
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/VariantV2ReadSuite.scala
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.execution.datasources.VariantMetadata
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{IntegerType, StringType, StructType, 
VariantType}
+
+class VariantV2ReadSuite extends QueryTest with SharedSparkSession {
+
+  private val testCatalogClass = 
"org.apache.spark.sql.connector.catalog.InMemoryTableCatalog"
+
+  private def withV2Catalog(f: => Unit): Unit = {
+    withSQLConf(
+      SQLConf.DEFAULT_CATALOG.key -> "testcat",
+      s"spark.sql.catalog.testcat" -> testCatalogClass,
+      SQLConf.USE_V1_SOURCE_LIST.key -> "",
+      SQLConf.PUSH_VARIANT_INTO_SCAN.key -> "true",
+      SQLConf.VARIANT_ALLOW_READING_SHREDDED.key -> "true") {
+      f
+    }
+  }
+
+  test("DSV2: push variant_get fields") {
+    withV2Catalog {
+      sql("DROP TABLE IF EXISTS testcat.ns.users")
+      sql(
+        """CREATE TABLE testcat.ns.users (
+          |  id bigint,
+          |  name string,
+          |  v variant,
+          |  vd variant default parse_json('1')
+          |) USING parquet""".stripMargin)
+
+      val out = sql(
+        """
+          |SELECT
+          |  id,
+          |  variant_get(v, '$.username', 'string') as username,
+          |  variant_get(v, '$.age', 'int') as age
+          |FROM testcat.ns.users
+          |WHERE variant_get(v, '$.status', 'string') = 'active'
+          |""".stripMargin)
+
+      checkAnswer(out, Seq.empty)
+
+      // Verify variant column rewrite
+      val optimized = out.queryExecution.optimizedPlan
+      val relOutput = optimized.collectFirst {
+        case s: DataSourceV2ScanRelation => s.output
+      }.getOrElse(fail("Expected DSv2 relation in optimized plan"))
+
+      val vAttr = relOutput.find(_.name == "v").getOrElse(fail("Missing 'v' 
column"))
+      vAttr.dataType match {
+        case s: StructType =>
+          assert(s.fields.length == 3,
+            s"Expected 3 fields (username, age, status), got 
${s.fields.length}")
+          
assert(s.fields.forall(_.metadata.contains(VariantMetadata.METADATA_KEY)),
+            "All fields should have VariantMetadata")
+
+          val paths = s.fields.map(f => 
VariantMetadata.fromMetadata(f.metadata).path).toSet
+          assert(paths == Set("$.username", "$.age", "$.status"),
+            s"Expected username, age, status paths, got: $paths")
+
+          val fieldTypes = s.fields.map(_.dataType).toSet
+          assert(fieldTypes.contains(StringType), "Expected StringType for 
string fields")
+          assert(fieldTypes.contains(IntegerType), "Expected IntegerType for 
age")
+
+        case other =>
+          fail(s"Expected StructType for 'v', got: $other")
+      }
+
+      // Verify variant with default value is NOT rewritten
+      relOutput.find(_.name == "vd").foreach { vdAttr =>
+        assert(vdAttr.dataType == VariantType,
+          "Variant column with default value should not be rewritten")
+      }
+    }
+  }
+
+  test("DSV2: nested column pruning for variant struct") {
+    withV2Catalog {
+      sql("DROP TABLE IF EXISTS testcat.ns.users2")
+      sql(
+        """CREATE TABLE testcat.ns.users2 (
+          |  id bigint,
+          |  name string,
+          |  v variant
+          |) USING parquet""".stripMargin)
+
+      val out = sql(
+        """
+          |SELECT id, variant_get(v, '$.username', 'string') as username
+          |FROM testcat.ns.users2
+          |""".stripMargin)
+
+      checkAnswer(out, Seq.empty)
+
+      val scan = out.queryExecution.executedPlan.collectFirst {
+        case b: BatchScanExec => b.scan
+      }.getOrElse(fail("Expected BatchScanExec in physical plan"))
+
+      val readSchema = scan.readSchema()
+
+      // Verify 'v' field exists and is a struct
+      val vField = readSchema.fields.find(_.name == "v").getOrElse(
+        fail("Expected 'v' field in read schema")
+      )
+
+      vField.dataType match {
+        case s: StructType =>
+          assert(s.fields.length == 1,
+            "Expected only 1 field ($.username) in pruned schema, got " + 
s.fields.length + ": " +
+              s.fields.map(f => 
VariantMetadata.fromMetadata(f.metadata).path).mkString(", "))
+
+          val field = s.fields(0)
+          assert(field.metadata.contains(VariantMetadata.METADATA_KEY),
+            "Field should have VariantMetadata")
+
+          val metadata = VariantMetadata.fromMetadata(field.metadata)
+          assert(metadata.path == "$.username",
+            "Expected path '$.username', got '" + metadata.path + "'")
+          assert(field.dataType == StringType,
+            s"Expected StringType, got ${field.dataType}")
+
+        case other =>
+          fail(s"Expected StructType for 'v' after rewrite and pruning, got: 
$other")
+      }
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to