This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new 03a837e883 Add tests for `BatchCoalescer::push_batch_with_filter`, fix 
bug (#7774)
03a837e883 is described below

commit 03a837e883323ef7e3294f0805c9e1cadd3963b8
Author: Andrew Lamb <[email protected]>
AuthorDate: Wed Jul 16 16:08:10 2025 -0400

    Add tests for `BatchCoalescer::push_batch_with_filter`, fix bug (#7774)
    
    # Which issue does this PR close?
    
    
    - Part of https://github.com/apache/arrow-rs/issues/7762
    
    
    # Rationale for this change
    
    As part of https://github.com/apache/arrow-rs/issues/7762 I want to
    optimize applying filters by adding a new code path.
    
    To make sure that works well, let's ensure the filtered code path is well
    covered by tests.
    
    
    # What changes are included in this PR?
    
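    1. Add tests for filtering batches with 0.1%, 1%, 10% and 90% selectivity,
    across columns of varying data types
    2. Fix `InProgressPrimitiveArray` so it preserves the input data type
    (e.g. timestamp timezone, decimal precision/scale)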
    
    
    # Are these changes tested?
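    Yes. Apart from the new tests, the only functional change is the
    `InProgressPrimitiveArray` fix, which now keeps the input data type
    (e.g. timezone, precision/scale) rather than the primitive type's default.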
    
    
    # Are there any user-facing changes?
---
 arrow-select/src/coalesce.rs           | 236 +++++++++++++++++++++++++++++++--
 arrow-select/src/coalesce/primitive.rs |  11 +-
 2 files changed, 234 insertions(+), 13 deletions(-)

diff --git a/arrow-select/src/coalesce.rs b/arrow-select/src/coalesce.rs
index 2360f25354..37741de3bc 100644
--- a/arrow-select/src/coalesce.rs
+++ b/arrow-select/src/coalesce.rs
@@ -342,7 +342,10 @@ impl BatchCoalescer {
 fn create_in_progress_array(data_type: &DataType, batch_size: usize) -> 
Box<dyn InProgressArray> {
     macro_rules! instantiate_primitive {
         ($t:ty) => {
-            Box::new(InProgressPrimitiveArray::<$t>::new(batch_size))
+            Box::new(InProgressPrimitiveArray::<$t>::new(
+                batch_size,
+                data_type.clone(),
+            ))
         };
     }
 
@@ -391,9 +394,11 @@ mod tests {
     use arrow_array::builder::StringViewBuilder;
     use arrow_array::cast::AsArray;
     use arrow_array::{
-        BinaryViewArray, RecordBatchOptions, StringArray, StringViewArray, 
UInt32Array,
+        BinaryViewArray, Int64Array, RecordBatchOptions, StringArray, 
StringViewArray,
+        TimestampNanosecondArray, UInt32Array,
     };
     use arrow_schema::{DataType, Field, Schema};
+    use rand::{Rng, SeedableRng};
     use std::ops::Range;
 
     #[test]
@@ -484,6 +489,98 @@ mod tests {
             .run();
     }
 
+    /// Coalesce multiple batches, 80k rows, with a 0.1% selectivity filter
+    #[test]
+    fn test_coalesce_filtered_001() {
+        let mut filter_builder = RandomFilterBuilder {
+            num_rows: 8000,
+            selectivity: 0.001,
+            seed: 0,
+        };
+
+        // add 10 batches of 8000 rows each, each paired with its own random filter
+        // 80k rows, selecting 0.1% means about 80 rows on average,
+        // but not exactly 80 as the rows are random; the fixed seeds yield 58.
+        let mut test = Test::new();
+        for _ in 0..10 {
+            test = test
+                .with_batch(multi_column_batch(0..8000))
+                .with_filter(filter_builder.next_filter())
+        }
+        test.with_batch_size(15)
+            .with_expected_output_sizes(vec![15, 15, 15, 13])
+            .run();
+    }
+
+    /// Coalesce multiple batches, 80k rows, with a 1% selectivity filter
+    #[test]
+    fn test_coalesce_filtered_01() {
+        let mut filter_builder = RandomFilterBuilder {
+            num_rows: 8000,
+            selectivity: 0.01,
+            seed: 0,
+        };
+
+        // add 10 batches of 8000 rows each, each paired with its own random filter
+        // 80k rows, selecting 1% means about 800 rows on average,
+        // but not exactly 800 as the rows are random; the fixed seeds yield 783.
+        let mut test = Test::new();
+        for _ in 0..10 {
+            test = test
+                .with_batch(multi_column_batch(0..8000))
+                .with_filter(filter_builder.next_filter())
+        }
+        test.with_batch_size(128)
+            .with_expected_output_sizes(vec![128, 128, 128, 128, 128, 128, 15])
+            .run();
+    }
+
+    /// Coalesce multiple batches, 80k rows, with a 10% selectivity filter
+    #[test]
+    fn test_coalesce_filtered_1() {
+        let mut filter_builder = RandomFilterBuilder {
+            num_rows: 8000,
+            selectivity: 0.1,
+            seed: 0,
+        };
+
+        // add 10 batches of 8000 rows each, each paired with its own random filter
+        // 80k rows, selecting 10% means about 8000 rows on average,
+        // but not exactly 8000 as the rows are random; the fixed seeds yield 8008.
+        let mut test = Test::new();
+        for _ in 0..10 {
+            test = test
+                .with_batch(multi_column_batch(0..8000))
+                .with_filter(filter_builder.next_filter())
+        }
+        test.with_batch_size(1024)
+            .with_expected_output_sizes(vec![1024, 1024, 1024, 1024, 1024, 
1024, 1024, 840])
+            .run();
+    }
+
+    /// Coalesce multiple batches, 8k rows, with a 90% selectivity filter
+    #[test]
+    fn test_coalesce_filtered_90() {
+        let mut filter_builder = RandomFilterBuilder {
+            num_rows: 800,
+            selectivity: 0.90,
+            seed: 0,
+        };
+
+        // add 10 batches of 800 rows each, each paired with its own random filter
+        // 8k rows, selecting 90% means about 7200 rows on average,
+        // but not exactly 7200 as the rows are random; the fixed seeds yield 7181.
+        let mut test = Test::new();
+        for _ in 0..10 {
+            test = test
+                .with_batch(multi_column_batch(0..800))
+                .with_filter(filter_builder.next_filter())
+        }
+        test.with_batch_size(1024)
+            .with_expected_output_sizes(vec![1024, 1024, 1024, 1024, 1024, 
1024, 1024, 13])
+            .run();
+    }
+
     #[test]
     fn test_coalesce_non_null() {
         Test::new()
@@ -862,6 +959,11 @@ mod tests {
     struct Test {
         /// Batches to feed to the coalescer.
         input_batches: Vec<RecordBatch>,
+        /// Filters to apply to the corresponding input batches.
+        ///
+        /// If there are no filters for the input batches, the batch will be
+        /// pushed as is.
+        filters: Vec<BooleanArray>,
         /// The schema. If not provided, the first batch's schema is used.
         schema: Option<SchemaRef>,
         /// Expected output sizes of the resulting batches
@@ -874,6 +976,7 @@ mod tests {
         fn default() -> Self {
             Self {
                 input_batches: vec![],
+                filters: vec![],
                 schema: None,
                 expected_output_sizes: vec![],
                 target_batch_size: 1024,
@@ -898,6 +1001,12 @@ mod tests {
             self
         }
 
+        /// Append `filter`; the N-th filter is applied to the N-th input batch
+        fn with_filter(mut self, filter: BooleanArray) -> Self {
+            self.filters.push(filter);
+            self
+        }
+
         /// Extends the input batches with `batches`
         fn with_batches(mut self, batches: impl IntoIterator<Item = 
RecordBatch>) -> Self {
             self.input_batches.extend(batches);
@@ -920,23 +1029,29 @@ mod tests {
         ///
         /// Returns the resulting output batches
         fn run(self) -> Vec<RecordBatch> {
+            let expected_output = self.expected_output();
+            let schema = self.schema();
+
             let Self {
                 input_batches,
-                schema,
+                filters,
+                schema: _,
                 target_batch_size,
                 expected_output_sizes,
             } = self;
 
-            let schema = schema.unwrap_or_else(|| input_batches[0].schema());
-
-            // create a single large input batch for output comparison
-            let single_input_batch = concat_batches(&schema, 
&input_batches).unwrap();
+            let had_input = input_batches.iter().any(|b| b.num_rows() > 0);
 
             let mut coalescer = BatchCoalescer::new(Arc::clone(&schema), 
target_batch_size);
 
-            let had_input = input_batches.iter().any(|b| b.num_rows() > 0);
+            // feed input batches and filters to the coalescer
+            let mut filters = filters.into_iter();
             for batch in input_batches {
-                coalescer.push_batch(batch).unwrap();
+                if let Some(filter) = filters.next() {
+                    coalescer.push_batch_with_filter(batch, &filter).unwrap();
+                } else {
+                    coalescer.push_batch(batch).unwrap();
+                }
             }
             assert_eq!(schema, coalescer.schema());
 
@@ -976,7 +1091,7 @@ mod tests {
             for (i, (expected_size, batch)) in iter {
                 // compare the contents of the batch after normalization (using
                 // `==` compares the underlying memory layout too)
-                let expected_batch = single_input_batch.slice(starting_idx, 
*expected_size);
+                let expected_batch = expected_output.slice(starting_idx, 
*expected_size);
                 let expected_batch = normalize_batch(expected_batch);
                 let batch = normalize_batch(batch.clone());
                 assert_eq!(
@@ -988,6 +1103,36 @@ mod tests {
             }
             output_batches
         }
+
+        /// Return the expected output schema. If not overridden by 
`with_schema`, it
+        /// returns the schema of the first input batch (panics if there are no input batches).
+        fn schema(&self) -> SchemaRef {
+            self.schema
+                .clone()
+                .unwrap_or_else(|| self.input_batches[0].schema())
+        }
+
+        /// Returns the expected output as a single `RecordBatch`: each input batch filtered by the filter at the same position (if any), then concatenated
+        fn expected_output(&self) -> RecordBatch {
+            let schema = self.schema();
+            if self.filters.is_empty() {
+                return concat_batches(&schema, &self.input_batches).unwrap();
+            }
+
+            let mut filters = self.filters.iter();
+            let filtered_batches = self
+                .input_batches
+                .iter()
+                .map(|batch| {
+                    if let Some(filter) = filters.next() {
+                        filter_record_batch(batch, filter).unwrap()
+                    } else {
+                        batch.clone()
+                    }
+                })
+                .collect::<Vec<_>>();
+            concat_batches(&schema, &filtered_batches).unwrap()
+        }
     }
 
     /// Return a RecordBatch with a UInt32Array with the specified range and
@@ -1063,6 +1208,77 @@ mod tests {
         RecordBatch::try_new(Arc::clone(&schema), 
vec![Arc::new(array)]).unwrap()
     }
 
+    /// Return a RecordBatch with one row per value in `range`, across four nullable columns
+    fn multi_column_batch(range: Range<i32>) -> RecordBatch {
+        let int64_array = Int64Array::from_iter(range.clone().map(|v| {
+            if v % 5 == 0 {
+                None
+            } else {
+                Some(v as i64)
+            }
+        }));
+        let string_view_array = 
StringViewArray::from_iter(range.clone().map(|v| {
+            if v % 5 == 0 {
+                None
+            } else if v % 7 == 0 {
+                Some(format!("This is a string longer than 12 bytes{v}"))
+            } else {
+                Some(format!("Short {v}"))
+            }
+        }));
+        let string_array = StringArray::from_iter(range.clone().map(|v| {
+            if v % 11 == 0 {
+                None
+            } else {
+                Some(format!("Value {v}"))
+            }
+        }));
+        let timestamp_array = 
TimestampNanosecondArray::from_iter(range.map(|v| {
+            if v % 3 == 0 {
+                None
+            } else {
+                Some(v as i64 * 1000) // v microseconds, stored in nanosecond units
+            }
+        }))
+        .with_timezone("America/New_York");
+
+        RecordBatch::try_from_iter(vec![
+            ("int64", Arc::new(int64_array) as ArrayRef),
+            ("stringview", Arc::new(string_view_array) as ArrayRef),
+            ("string", Arc::new(string_array) as ArrayRef),
+            ("timestamp", Arc::new(timestamp_array) as ArrayRef),
+        ])
+        .unwrap()
+    }
+
+    /// Builds seeded random boolean filters, each keeping every row
+    /// independently with probability `selectivity`.
+    ///
+    /// For example a `selectivity` of 0.1 keeps roughly 10% of the rows
+    /// (i.e. filters out about 90% of them).
+    #[derive(Debug)]
+    struct RandomFilterBuilder {
+        num_rows: usize, // length of each generated filter
+        selectivity: f64, // probability that a row is selected, in 0.0..=1.0
+        /// seed for random number generator, increases by one each time
+        /// `next_filter` is called
+        seed: u64,
+    }
+    impl RandomFilterBuilder {
+        /// Build the next filter with the current seed and increment the seed
+        /// by one. Panics if `selectivity` is not within `0.0..=1.0`.
+        fn next_filter(&mut self) -> BooleanArray {
+            assert!((0.0..=1.0).contains(&self.selectivity), "selectivity must be in 0.0..=1.0");
+            let mut rng = rand::rngs::StdRng::seed_from_u64(self.seed);
+            self.seed += 1;
+            BooleanArray::from_iter(
+                (0..self.num_rows)
+                    .map(|_| rng.random_bool(self.selectivity))
+                    .map(Some),
+            )
+        }
+    }
+
     /// Returns the named column as a StringViewArray
     fn col_as_string_view<'b>(name: &str, batch: &'b RecordBatch) -> &'b 
StringViewArray {
         batch
diff --git a/arrow-select/src/coalesce/primitive.rs 
b/arrow-select/src/coalesce/primitive.rs
index 8355f24f31..85b653357b 100644
--- a/arrow-select/src/coalesce/primitive.rs
+++ b/arrow-select/src/coalesce/primitive.rs
@@ -19,13 +19,15 @@ use crate::coalesce::InProgressArray;
 use arrow_array::cast::AsArray;
 use arrow_array::{Array, ArrayRef, ArrowPrimitiveType, PrimitiveArray};
 use arrow_buffer::{NullBufferBuilder, ScalarBuffer};
-use arrow_schema::ArrowError;
+use arrow_schema::{ArrowError, DataType};
 use std::fmt::Debug;
 use std::sync::Arc;
 
 /// InProgressArray for [`PrimitiveArray`]
 #[derive(Debug)]
 pub(crate) struct InProgressPrimitiveArray<T: ArrowPrimitiveType> {
+    /// Data type of the array
+    data_type: DataType,
     /// The current source, if any
     source: Option<ArrayRef>,
     /// the target batch size (and thus size for views allocation)
@@ -38,8 +40,9 @@ pub(crate) struct InProgressPrimitiveArray<T: 
ArrowPrimitiveType> {
 
 impl<T: ArrowPrimitiveType> InProgressPrimitiveArray<T> {
     /// Create a new `InProgressPrimitiveArray`
-    pub(crate) fn new(batch_size: usize) -> Self {
+    pub(crate) fn new(batch_size: usize, data_type: DataType) -> Self {
         Self {
+            data_type,
             batch_size,
             source: None,
             nulls: NullBufferBuilder::new(batch_size),
@@ -95,7 +98,9 @@ impl<T: ArrowPrimitiveType + Debug> InProgressArray for 
InProgressPrimitiveArray
         let nulls = self.nulls.finish();
         self.nulls = NullBufferBuilder::new(self.batch_size);
 
-        let array = PrimitiveArray::<T>::try_new(ScalarBuffer::from(values), 
nulls)?;
+        let array = PrimitiveArray::<T>::try_new(ScalarBuffer::from(values), 
nulls)?
+            // preserve timezone / precision+scale if applicable
+            .with_data_type(self.data_type.clone());
         Ok(Arc::new(array))
     }
 }

Reply via email to