This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new dff6402457 Let `ArrowArrayStreamReader` handle schema with attached metadata + do schema checking (#8944)
dff6402457 is described below

commit dff6402457910389af785d768e0f9e274d9526a7
Author: Jonas Dedden <[email protected]>
AuthorDate: Tue Dec 9 23:37:54 2025 +0100

    Let `ArrowArrayStreamReader` handle schema with attached metadata + do schema checking (#8944)
    
    # Which issue does this PR close? / Rationale for this change
    
    Solves an issue discovered during
    https://github.com/apache/arrow-rs/pull/8790, namely that
    `ArrowArrayStreamReader` does not correctly expose schema-level metadata
    and does not check whether the `StructArray` constructed from the FFI
    stream actually corresponds to the expected schema.
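    
    For context, a minimal standalone sketch (not part of this commit; names
    and values are purely illustrative) of the metadata loss: converting a
    batch through a `StructArray` keeps the columns and field-level metadata,
    but the schema-level metadata map does not survive.
    
    ```rust
    use std::{collections::HashMap, sync::Arc};
    
    use arrow_array::{ArrayRef, Int32Array, RecordBatch, StructArray};
    use arrow_schema::{DataType, Field, Schema};
    
    fn main() {
        let metadata = HashMap::from([("foo".to_owned(), "bar".to_owned())]);
        let schema = Arc::new(
            Schema::new(vec![Field::new("a", DataType::Int32, true)]).with_metadata(metadata),
        );
        let col: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
        let batch = RecordBatch::try_new(schema.clone(), vec![col]).unwrap();
    
        // Round-tripping through a StructArray keeps the columns and the
        // field-level metadata, but the schema-level metadata map is gone.
        let via_struct = RecordBatch::from(StructArray::from(batch));
        assert!(via_struct.schema().metadata().is_empty());
        assert_eq!(schema.metadata().len(), 1);
    }
    ```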
    
    # What changes are included in this PR?
    
    - Change how the `RecordBatch` is constructed inside `ArrowArrayStreamReader`
    so that it carries the schema metadata and schema validity checks are
    performed (see the sketch after this list).
    - Augment FFI tests with schema- and column-level metadata.
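    
    A minimal sketch (illustrative only, not taken from this diff) of the kind
    of mismatch that `RecordBatch::try_new` rejects and that the reader
    therefore now surfaces as an error:
    
    ```rust
    use std::sync::Arc;
    
    use arrow_array::{ArrayRef, RecordBatch, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    
    fn main() {
        // The stream's declared schema says column "a" is Int32 ...
        let declared = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
        // ... but the producer hands back a Utf8 column instead.
        let wrong_col: ArrayRef = Arc::new(StringArray::from(vec!["oops"]));
        // `try_new` validates the columns against the schema and returns an
        // error instead of silently building an inconsistent batch.
        assert!(RecordBatch::try_new(declared, vec![wrong_col]).is_err());
    }
    ```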
    
    # Are these changes tested?
    
    Yes, both `_test_round_trip_export` and `_test_round_trip_import` now
    check metadata at both the schema and the column level.
    
    # Are there any user-facing changes?
    
    Yes, `ArrowArrayStreamReader` is now able to export `RecordBatch`es with
    schema-level metadata, and the interface is safer since it now actually
    validates the data against the expected schema.
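    
    As a hypothetical usage sketch (the reader set-up is assumed, not part of
    this change): a schema-level metadata map now survives a round trip
    through the C stream interface.
    
    ```rust
    use arrow_array::ffi_stream::{ArrowArrayStreamReader, FFI_ArrowArrayStream};
    use arrow_array::RecordBatchReader;
    
    // `reader` is any RecordBatchReader whose schema carries metadata
    // (an assumption made for this sketch).
    fn roundtrip(reader: Box<dyn RecordBatchReader + Send>) {
        // Export the reader through the C stream interface ...
        let stream = FFI_ArrowArrayStream::new(reader);
        // ... and import it again on the consumer side.
        let imported = ArrowArrayStreamReader::try_new(stream).unwrap();
        // The imported schema keeps its metadata, and each produced batch now
        // carries that metadata and is validated against the schema.
        assert!(!imported.schema().metadata().is_empty());
        for batch in imported {
            assert!(!batch.unwrap().schema().metadata().is_empty());
        }
    }
    ```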
---
 arrow-array/src/ffi_stream.rs                      | 45 ++++++++++++++++------
 .../tests/test_sql.py                              | 28 ++++++++------
 2 files changed, 49 insertions(+), 24 deletions(-)

diff --git a/arrow-array/src/ffi_stream.rs b/arrow-array/src/ffi_stream.rs
index 27c020e5c0..c469436829 100644
--- a/arrow-array/src/ffi_stream.rs
+++ b/arrow-array/src/ffi_stream.rs
@@ -364,7 +364,9 @@ impl Iterator for ArrowArrayStreamReader {
             let result = unsafe {
                 from_ffi_and_data_type(array, DataType::Struct(self.schema().fields().clone()))
             };
-            Some(result.map(|data| RecordBatch::from(StructArray::from(data))))
+            Some(result.and_then(|data| {
+                RecordBatch::try_new(self.schema.clone(), StructArray::from(data).into_parts().1)
+            }))
         } else {
             let last_error = self.get_stream_last_error();
             let err = ArrowError::CDataInterface(last_error.unwrap());
@@ -382,6 +384,7 @@ impl RecordBatchReader for ArrowArrayStreamReader {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use std::collections::HashMap;
 
     use arrow_schema::Field;
 
@@ -417,11 +420,18 @@ mod tests {
     }
 
     fn _test_round_trip_export(arrays: Vec<Arc<dyn Array>>) -> Result<()> {
-        let schema = Arc::new(Schema::new(vec![
-            Field::new("a", arrays[0].data_type().clone(), true),
-            Field::new("b", arrays[1].data_type().clone(), true),
-            Field::new("c", arrays[2].data_type().clone(), true),
-        ]));
+        let metadata = HashMap::from([("foo".to_owned(), "bar".to_owned())]);
+        let schema = Arc::new(Schema::new_with_metadata(
+            vec![
+                Field::new("a", arrays[0].data_type().clone(), true)
+                    .with_metadata(metadata.clone()),
+                Field::new("b", arrays[1].data_type().clone(), true)
+                    .with_metadata(metadata.clone()),
+                Field::new("c", arrays[2].data_type().clone(), true)
+                    .with_metadata(metadata.clone()),
+            ],
+            metadata,
+        ));
         let batch = RecordBatch::try_new(schema.clone(), arrays).unwrap();
         let iter = Box::new(vec![batch.clone(), batch.clone()].into_iter().map(Ok)) as _;
 
@@ -452,7 +462,11 @@ mod tests {
 
             let array = unsafe { from_ffi(ffi_array, &ffi_schema) }.unwrap();
 
-            let record_batch = RecordBatch::from(StructArray::from(array));
+            let record_batch = RecordBatch::try_new(
+                SchemaRef::from(exported_schema.clone()),
+                StructArray::from(array).into_parts().1,
+            )
+            .unwrap();
             produced_batches.push(record_batch);
         }
 
@@ -462,11 +476,18 @@ mod tests {
     }
 
     fn _test_round_trip_import(arrays: Vec<Arc<dyn Array>>) -> Result<()> {
-        let schema = Arc::new(Schema::new(vec![
-            Field::new("a", arrays[0].data_type().clone(), true),
-            Field::new("b", arrays[1].data_type().clone(), true),
-            Field::new("c", arrays[2].data_type().clone(), true),
-        ]));
+        let metadata = HashMap::from([("foo".to_owned(), "bar".to_owned())]);
+        let schema = Arc::new(Schema::new_with_metadata(
+            vec![
+                Field::new("a", arrays[0].data_type().clone(), true)
+                    .with_metadata(metadata.clone()),
+                Field::new("b", arrays[1].data_type().clone(), true)
+                    .with_metadata(metadata.clone()),
+                Field::new("c", arrays[2].data_type().clone(), true)
+                    .with_metadata(metadata.clone()),
+            ],
+            metadata,
+        ));
         let batch = RecordBatch::try_new(schema.clone(), arrays).unwrap();
         let iter = Box::new(vec![batch.clone(), batch.clone()].into_iter().map(Ok)) as _;
 
diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py
index 79220fb6a6..b9b04ddee5 100644
--- a/arrow-pyarrow-integration-testing/tests/test_sql.py
+++ b/arrow-pyarrow-integration-testing/tests/test_sql.py
@@ -527,7 +527,7 @@ def test_empty_recordbatch_with_row_count():
     """
 
     # Create an empty schema with no fields
-    batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3, 4]}).select([])
+    batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3, 4]}, metadata={b'key1': b'value1'}).select([])
     num_rows = 4
     assert batch.num_rows == num_rows
     assert batch.num_columns == 0
@@ -545,7 +545,7 @@ def test_record_batch_reader():
     """
     Python -> Rust -> Python
     """
-    schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+    schema = pa.schema([pa.field(name='ints', type=pa.list_(pa.int32()), metadata={b'key1': b'value1'})], metadata={b'key1': b'value1'})
     batches = [
         pa.record_batch([[[1], [2, 42]]], schema),
         pa.record_batch([[None, [], [5, 6]]], schema),
@@ -571,7 +571,7 @@ def test_record_batch_reader_pycapsule():
     """
     Python -> Rust -> Python
     """
-    schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+    schema = pa.schema([pa.field(name='ints', type=pa.list_(pa.int32()), metadata={b'key1': b'value1'})], metadata={b'key1': b'value1'})
     batches = [
         pa.record_batch([[[1], [2, 42]]], schema),
         pa.record_batch([[None, [], [5, 6]]], schema),
@@ -621,7 +621,7 @@ def test_record_batch_pycapsule():
     """
     Python -> Rust -> Python
     """
-    schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+    schema = pa.schema([pa.field(name='ints', type=pa.list_(pa.int32()), metadata={b'key1': b'value1'})], metadata={b'key1': b'value1'})
     batch = pa.record_batch([[[1], [2, 42]]], schema)
     wrapped = StreamWrapper(batch)
     b = rust.round_trip_record_batch_reader(wrapped)
@@ -640,7 +640,7 @@ def test_table_pycapsule():
     """
     Python -> Rust -> Python
     """
-    schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+    schema = pa.schema([pa.field(name='ints', type=pa.list_(pa.int32()), metadata={b'key1': b'value1'})], metadata={b'key1': b'value1'})
     batches = [
         pa.record_batch([[[1], [2, 42]]], schema),
         pa.record_batch([[None, [], [5, 6]]], schema),
@@ -650,8 +650,9 @@ def test_table_pycapsule():
     b = rust.round_trip_record_batch_reader(wrapped)
     new_table = b.read_all()
 
-    assert table.schema == new_table.schema
     assert table == new_table
+    assert table.schema == new_table.schema
+    assert table.schema.metadata == new_table.schema.metadata
     assert len(table.to_batches()) == len(new_table.to_batches())
 
 
@@ -659,12 +660,13 @@ def test_table_empty():
     """
     Python -> Rust -> Python
     """
-    schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+    schema = pa.schema([pa.field(name='ints', type=pa.list_(pa.int32()), metadata={b'key1': b'value1'})], metadata={b'key1': b'value1'})
     table = pa.Table.from_batches([], schema=schema)
     new_table = rust.build_table([], schema=schema)
 
-    assert table.schema == new_table.schema
     assert table == new_table
+    assert table.schema == new_table.schema
+    assert table.schema.metadata == new_table.schema.metadata
     assert len(table.to_batches()) == len(new_table.to_batches())
 
 
@@ -672,7 +674,7 @@ def test_table_roundtrip():
     """
     Python -> Rust -> Python
     """
-    schema = pa.schema([('ints', pa.list_(pa.int32()))])
+    schema = pa.schema([pa.field(name='ints', type=pa.list_(pa.int32()), metadata={b'key1': b'value1'})], metadata={b'key1': b'value1'})
     batches = [
         pa.record_batch([[[1], [2, 42]]], schema),
         pa.record_batch([[None, [], [5, 6]]], schema),
@@ -680,8 +682,9 @@ def test_table_roundtrip():
     table = pa.Table.from_batches(batches, schema=schema)
     new_table = rust.round_trip_table(table)
 
-    assert table.schema == new_table.schema
     assert table == new_table
+    assert table.schema == new_table.schema
+    assert table.schema.metadata == new_table.schema.metadata
     assert len(table.to_batches()) == len(new_table.to_batches())
 
 
@@ -689,7 +692,7 @@ def test_table_from_batches():
     """
     Python -> Rust -> Python
     """
-    schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+    schema = pa.schema([pa.field(name='ints', type=pa.list_(pa.int32()), metadata={b'key1': b'value1'})], metadata={b'key1': b'value1'})
     batches = [
         pa.record_batch([[[1], [2, 42]]], schema),
         pa.record_batch([[None, [], [5, 6]]], schema),
@@ -697,8 +700,9 @@ def test_table_from_batches():
     table = pa.Table.from_batches(batches)
     new_table = rust.build_table(batches, schema)
 
-    assert table.schema == new_table.schema
     assert table == new_table
+    assert table.schema == new_table.schema
+    assert table.schema.metadata == new_table.schema.metadata
     assert len(table.to_batches()) == len(new_table.to_batches())
 
 
