This is an automated email from the ASF dual-hosted git repository.

gengliangwang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2adc69aca3de [SPARK-56132][SS] Call pruneColumns on V2 streaming to fix metadata reading issue
2adc69aca3de is described below

commit 2adc69aca3de7415d011fbe4b3dedaa3a987e8d5
Author: Zikang Han <[email protected]>
AuthorDate: Thu May 28 13:19:21 2026 -0700

    [SPARK-56132][SS] Call pruneColumns on V2 streaming to fix metadata reading issue
    
    ### What changes were proposed in this pull request?
    
    In `MicroBatchExecution.logicalPlan` (and the corresponding scan-building code in
    `ContinuousExecution`), before calling `build()` on the V2 streaming scan builder, call
    `SupportsPushDownRequiredColumns.pruneColumns(output.toStructType)` if the builder supports it.
    `output` is the analyzed relation output, which already includes any metadata columns the
    query references (added by the `AddMetadataColumns` rule).
    
    ### Why are the changes needed?
    
    1. **Metadata column reads in V2 streaming crash with `ArrayIndexOutOfBoundsException`.**
       When a query selects a metadata column (e.g. `_metadata.row_id`) from a V2 streaming
       source that implements both `SupportsMetadataColumns` and `SupportsPushDownRequiredColumns`,
       the analyzed plan expects the metadata column in the scan output, but `Scan.readSchema()`
       does not include it. Spark then tries to read a column at an index the scan never produced.
    
    2. **Root cause: `pruneColumns` is never called in streaming.**
       In batch, `V2ScanRelationPushDown` calls `SupportsPushDownRequiredColumns.pruneColumns`
       with the required schema (which includes metadata columns resolved by `AddMetadataColumns`)
       before `build()`. In `MicroBatchExecution.logicalPlan`, by contrast, the scan is built
       directly with `table.newScanBuilder(options).build()`, and no pushdown of any kind is
       applied. A `// TODO: operator pushdown` comment marks this, and `ContinuousExecution` has
       the same gap. Connectors that use `pruneColumns` to configure `readSchema()`, including
       whether to produce metadata columns, are therefore never told what the query needs.
    
    3. **This change fixes metadata column reads only, not column pruning.**
       We call `pruneColumns(output.toStructType)`, where `output` is the full analyzed relation
       output: all data columns plus any metadata columns added by `AddMetadataColumns`. This
       tells the scan builder which metadata columns are required so they appear in
       `readSchema()`, but it does not prune data columns. Full column pruning in streaming, along
       with filter and aggregate pushdown, is deferred to the existing TODO.
    
    ### Does this PR introduce _any_ user-facing change?
    
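    Yes. Streaming queries that select metadata columns from a V2 source implementing both
    `SupportsMetadataColumns` and `SupportsPushDownRequiredColumns` no longer fail with
    `ArrayIndexOutOfBoundsException`.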
    
    ### How was this patch tested?
    
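    Added tests in `DataStreamTableAPISuite` and `MetadataColumnSuite`. To support them, the
    test-only `InMemoryBaseTable` micro-batch stream now appends the metadata columns present in
    its read schema to each row it returns.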
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Yes
    
    Closes #56133 from zikangh/stack/prunecolumns-streaming.
    
    Authored-by: Zikang Han <[email protected]>
    Signed-off-by: Gengliang Wang <[email protected]>
---
 .../sql/connector/catalog/InMemoryBaseTable.scala  | 42 ++++++++++---
 .../streaming/continuous/ContinuousExecution.scala | 11 +++-
 .../streaming/runtime/MicroBatchExecution.scala    | 11 +++-
 .../spark/sql/connector/MetadataColumnSuite.scala  | 26 ++++++++
 .../streaming/test/DataStreamTableAPISuite.scala   | 73 +++++++++++++++++++++-
 5 files changed, 150 insertions(+), 13 deletions(-)

diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
index f582f3e408cb..06997662fd8b 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
@@ -545,7 +545,8 @@ abstract class InMemoryBaseTable(
     override def json(): String = rowCount.toString
   }
 
-  class InMemoryMicroBatchStream extends MicroBatchStream {
+  class InMemoryMicroBatchStream(readSchema: StructType, tableSchema: 
StructType)
+      extends MicroBatchStream {
     override def initialOffset(): Offset = new InMemoryTableOffset(0)
     override def latestOffset(): Offset =
       new InMemoryTableOffset(InMemoryBaseTable.this.rows.size.toLong)
@@ -554,14 +555,13 @@ abstract class InMemoryBaseTable(
       val e = end.asInstanceOf[InMemoryTableOffset].rowCount.toInt
       Array(InMemoryMicroBatchPartition(InMemoryBaseTable.this.rows.slice(s, 
e)))
     }
-    override def createReaderFactory(): PartitionReaderFactory = { partition =>
-      val rows = partition.asInstanceOf[InMemoryMicroBatchPartition].rows
-      new PartitionReader[InternalRow] {
-        private var idx = -1
-        override def next(): Boolean = { idx += 1; idx < rows.size }
-        override def get(): InternalRow = rows(idx)
-        override def close(): Unit = {}
+    override def createReaderFactory(): PartitionReaderFactory = {
+      val metadataColNames = new mutable.ArrayBuffer[String]()
+      readSchema.foreach {
+        case MetadataStructFieldWithLogicalName(_, name) => metadataColNames 
+= name
+        case _ =>
       }
+      new InMemoryMicroBatchReaderFactory(metadataColNames.toArray)
     }
     override def deserializeOffset(json: String): Offset = new 
InMemoryTableOffset(json.toLong)
     override def commit(end: Offset): Unit = {}
@@ -655,7 +655,7 @@ abstract class InMemoryBaseTable(
     }
 
     override def toMicroBatchStream(checkpointLocation: String): 
MicroBatchStream =
-      new InMemoryMicroBatchStream
+      new InMemoryMicroBatchStream(readSchema, tableSchema)
   }
 
   case class InMemoryBatchScan(
@@ -954,6 +954,30 @@ class BufferedRows(val key: Seq[Any], val schema: 
StructType)
   def clear(): Unit = rows.clear()
 }
 
+private class InMemoryMicroBatchReaderFactory(
+    metaNames: Array[String]) extends PartitionReaderFactory with Serializable 
{
+  override def createReader(partition: InputPartition): 
PartitionReader[InternalRow] = {
+    val rows = partition.asInstanceOf[InMemoryMicroBatchPartition].rows
+    new PartitionReader[InternalRow] {
+      private var idx = -1
+      override def next(): Boolean = { idx += 1; idx < rows.size }
+      override def get(): InternalRow = {
+        val rawRow = rows(idx)
+        if (metaNames.isEmpty) rawRow
+        else {
+          val metaRow = new GenericInternalRow(metaNames.map {
+            case "index" => idx.asInstanceOf[Any]
+            case "_partition" => UTF8String.fromString("").asInstanceOf[Any]
+            case _ => null
+          })
+          new JoinedRow(rawRow, metaRow)
+        }
+      }
+      override def close(): Unit = {}
+    }
+  }
+}
+
 object BufferedRows {
   def apply(key: Seq[Any], schema: Array[Column]): BufferedRows = {
     new BufferedRows(key, CatalogV2Util.v2ColumnsToStructType(schema))
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
index 14cd06038b5a..4c7a8437a46f 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
@@ -33,6 +33,7 @@ import 
org.apache.spark.sql.catalyst.trees.TreePattern.CURRENT_LIKE
 import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, 
TableCapability}
 import org.apache.spark.sql.connector.distributions.UnspecifiedDistribution
+import org.apache.spark.sql.connector.read.SupportsPushDownRequiredColumns
 import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, 
PartitionOffset, ReadLimit, SparkDataStream}
 import org.apache.spark.sql.connector.write.{RequiresDistributionAndOrdering, 
Write}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, 
QueryExecutionErrors}
@@ -92,7 +93,15 @@ class ContinuousExecution(
             log"from DataSourceV2 named '${MDC(STREAMING_DATA_SOURCE_NAME, 
sourceName)}' " +
             log"${MDC(STREAMING_DATA_SOURCE_DESCRIPTION, dsStr)}")
           // TODO: operator pushdown.
-          val scan = table.newScanBuilder(options).build()
+          // Passes the full output schema (not a pruned subset) so that 
connectors
+          // implementing SupportsMetadataColumns can include metadata columns 
in readSchema().
+          val scanBuilder = table.newScanBuilder(options)
+          scanBuilder match {
+            case r: SupportsPushDownRequiredColumns =>
+              r.pruneColumns(output.toStructType)
+            case _ =>
+          }
+          val scan = scanBuilder.build()
           val stream = scan.toContinuousStream(metadataPath)
           val relation = StreamingDataSourceV2Relation(
               table, output, catalog, identifier, options, metadataPath)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala
index 726586ac72e6..84f0373ca5d4 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala
@@ -39,6 +39,7 @@ import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.classic.{Dataset, SparkSession}
 import org.apache.spark.sql.classic.ClassicConversions.castToImpl
 import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, 
TableCapability, TransactionalCatalogPlugin}
+import org.apache.spark.sql.connector.read.SupportsPushDownRequiredColumns
 import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset 
=> OffsetV2, ReadLimit, SparkDataStream, SupportsAdmissionControl, 
SupportsRealTimeMode, SupportsTriggerAvailableNow}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
@@ -250,7 +251,15 @@ class MicroBatchExecution(
               log"from DataSourceV2 named 
'${MDC(LogKeys.STREAMING_DATA_SOURCE_NAME, srcName)}' " +
               log"${MDC(LogKeys.STREAMING_DATA_SOURCE_DESCRIPTION, dsStr)}")
             // TODO: operator pushdown.
-            val scan = table.newScanBuilder(options).build()
+            // Passes the full output schema (not a pruned subset) so that 
connectors
+            // implementing SupportsMetadataColumns can include metadata 
columns in readSchema().
+            val scanBuilder = table.newScanBuilder(options)
+            scanBuilder match {
+              case r: SupportsPushDownRequiredColumns =>
+                r.pruneColumns(output.toStructType)
+              case _ =>
+            }
+            val scan = scanBuilder.build()
             val stream = scan.toMicroBatchStream(metadataPath)
             val relation = StreamingDataSourceV2Relation(
                 table,
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala
index fe338175ec88..77e3818aafe8 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala
@@ -376,6 +376,32 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
     }
   }
 
+  test("SPARK-56132: streaming read of metadata columns from V2 source") {
+    withTable(tbl) {
+      prepareTable()
+      withTempDir { checkpointDir =>
+        // "index" is a metadata column (not in the table schema); "id" and 
"data" are data columns.
+        val df = spark.readStream.table(tbl).select("id", "data", "index")
+        val q = df.writeStream
+          .format("memory")
+          .queryName("result_56132")
+          .option("checkpointLocation", checkpointDir.getCanonicalPath)
+          .start()
+        try {
+          q.processAllAvailable()
+          val result = spark.table("result_56132")
+          // Verify data columns arrive correctly and index (metadata) is 
non-null.
+          checkAnswer(result.select("id", "data").orderBy("id"),
+            Seq(Row(1, "a"), Row(2, "b"), Row(3, "c")))
+          assert(result.select("index").collect().forall(!_.isNullAt(0)),
+            "index metadata column should be non-null in streaming output")
+        } finally {
+          q.stop()
+        }
+      }
+    }
+  }
+
   test("SPARK-43123: Metadata column related field metadata should not be 
leaked to catalogs") {
     withTable(tbl, "testcat.target") {
       prepareTable()
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala
index f10d1cdab0d5..3930beec084d 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala
@@ -30,7 +30,8 @@ import 
org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
 import org.apache.spark.sql.connector.{FakeV2Provider, 
FakeV2ProviderWithCustomSchema, InMemoryTableSessionCatalog}
 import org.apache.spark.sql.connector.catalog.{Column, Identifier, 
InMemoryTable, InMemoryTableCatalog, MetadataColumn, SupportsMetadataColumns, 
SupportsRead, Table, TableCapability, TableInfo, V2TableWithV1Fallback}
 import org.apache.spark.sql.connector.expressions.{ClusterByTransform, 
FieldReference, Transform}
-import org.apache.spark.sql.connector.read.ScanBuilder
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, 
SupportsPushDownRequiredColumns}
+import org.apache.spark.sql.connector.read.streaming.MicroBatchStream
 import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, 
MemoryStreamScanBuilder, StreamingQueryWrapper}
 import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.internal.SQLConf
@@ -564,6 +565,42 @@ class DataStreamTableAPISuite extends StreamTest with 
BeforeAndAfter {
     }
   }
 
+  test("SPARK-56132: pruneColumns called on SupportsPushDownRequiredColumns " +
+      "V2 streaming scan builder") {
+    val tblName = "teststream.table_name"
+    withTable(tblName) {
+      spark.sql(s"CREATE TABLE $tblName (data int) USING foo")
+      val stream = MemoryStream[Int]
+      val testCatalog = 
spark.sessionState.catalogManager.catalog("teststream").asTableCatalog
+      val table = testCatalog.loadTable(Identifier.of(Array(), "table_name"))
+        .asInstanceOf[InMemoryStreamTable]
+      table.setStream(stream)
+
+      // Wrap the table's scan builder so we can record pruneColumns calls.
+      val recorded = new PrunedSchemaRecorder
+      table.scanBuilderWrapper = Some(inner => new 
RecordingPruneScanBuilder(inner, recorded))
+
+      withTempDir { checkpointDir =>
+        val q = spark.readStream.table(tblName)
+          .select("value", "_seq")
+          .writeStream.format("noop")
+          .option("checkpointLocation", checkpointDir.getCanonicalPath)
+          .start()
+        try {
+          // logicalPlan is initialized lazily when the query thread starts; 
wait for it.
+          eventually(timeout(streamingTimeout)) {
+            assert(recorded.called,
+              "pruneColumns should have been called on the streaming scan 
builder")
+          }
+          assert(recorded.schema.fieldNames.toSet === Set("value", "_seq"),
+            s"Expected pruneColumns to receive {value, _seq}, got 
${recorded.schema}")
+        } finally {
+          q.stop()
+        }
+      }
+    }
+  }
+
   private def checkForStreamTable(dir: Option[File], tableName: String): Unit 
= {
     val memory = MemoryStream[Int]
     val dsw = memory.toDS().writeStream.format("parquet")
@@ -683,6 +720,7 @@ class InMemoryStreamTable(override val name: String)
   with SupportsRead
   with SupportsMetadataColumns {
   var stream: MemoryStream[Int] = _
+  var scanBuilderWrapper: Option[MemoryStreamScanBuilder => ScanBuilder] = None
 
   def setStream(inputData: MemoryStream[Int]): Unit = stream = inputData
 
@@ -693,7 +731,8 @@ class InMemoryStreamTable(override val name: String)
   }
 
   override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder 
= {
-    new MemoryStreamScanBuilder(stream)
+    val inner = new MemoryStreamScanBuilder(stream)
+    scanBuilderWrapper.map(_(inner)).getOrElse(inner)
   }
 
   private object SeqColumn extends MetadataColumn {
@@ -705,6 +744,36 @@ class InMemoryStreamTable(override val name: String)
   override val metadataColumns: Array[MetadataColumn] = Array(SeqColumn)
 }
 
+class PrunedSchemaRecorder {
+  @volatile var called = false
+  @volatile var schema: StructType = new StructType()
+}
+
+class RecordingPruneScanBuilder(inner: MemoryStreamScanBuilder, recorder: 
PrunedSchemaRecorder)
+    extends ScanBuilder
+    with SupportsPushDownRequiredColumns {
+
+  override def pruneColumns(requiredSchema: StructType): Unit = {
+    recorder.called = true
+    recorder.schema = requiredSchema
+  }
+
+  override def build(): Scan = {
+    val innerScan = inner.build()
+    val prunedSchema = recorder.schema
+    // Return a scan whose readSchema() reflects the pruned schema so the 
streaming plan
+    // and scan agree on output columns. Without the fix, pruneColumns is 
never called and
+    // readSchema() defaults to the full table schema, causing 
ArrayIndexOutOfBoundsException
+    // when metadata columns are in the plan output but absent from the scan 
output.
+    new Scan {
+      override def readSchema(): StructType =
+        if (recorder.called) prunedSchema else innerScan.readSchema()
+      override def toMicroBatchStream(checkpointLocation: String): 
MicroBatchStream =
+        innerScan.toMicroBatchStream(checkpointLocation)
+    }
+  }
+}
+
 class NonStreamV2Table(override val name: String)
     extends Table with SupportsRead with V2TableWithV1Fallback {
   override def schema(): StructType = StructType(Nil)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to