liurenjie1024 commented on code in PR #295: URL: https://github.com/apache/iceberg-rust/pull/295#discussion_r1547582654
########## crates/iceberg/src/arrow.rs: ########## @@ -113,6 +143,405 @@ impl ArrowReader { // TODO: full implementation ProjectionMask::all() } + + fn get_row_filter(&self, parquet_schema: &SchemaDescriptor) -> Result<Option<RowFilter>> { + if let Some(predicates) = &self.predicates { + let field_id_map = self.build_field_id_map(parquet_schema)?; + + // Collect Parquet column indices from field ids + let column_indices = predicates + .iter() + .map(|predicate| { + let mut collector = CollectFieldIdVisitor { field_ids: vec![] }; + collector.visit_predicate(predicate).unwrap(); + collector + .field_ids + .iter() + .map(|field_id| { + field_id_map.get(field_id).cloned().ok_or_else(|| { + Error::new(ErrorKind::DataInvalid, "Field id not found in schema") + }) + }) + .collect::<Result<Vec<_>>>() + }) + .collect::<Result<Vec<_>>>()?; + + // Convert BoundPredicates to ArrowPredicates + let mut arrow_predicates = vec![]; + for (predicate, columns) in predicates.iter().zip(column_indices.iter()) { + let mut converter = PredicateConverter { + columns, + projection_mask: ProjectionMask::leaves(parquet_schema, columns.clone()), + parquet_schema, + column_map: &field_id_map, + }; + let arrow_predicate = converter.visit_predicate(predicate)?; + arrow_predicates.push(arrow_predicate); + } + Ok(Some(RowFilter::new(arrow_predicates))) + } else { + Ok(None) + } + } + + /// Build the map of field id to Parquet column index in the schema. + fn build_field_id_map(&self, parquet_schema: &SchemaDescriptor) -> Result<HashMap<i32, usize>> { + let mut column_map = HashMap::new(); + for (idx, field) in parquet_schema.columns().iter().enumerate() { + let field_type = field.self_type(); + match field_type { + ParquetType::PrimitiveType { basic_info, .. 
} => { + if !basic_info.has_id() { + return Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Leave column {:?} in schema doesn't have field id", + field_type + ), + )); + } + column_map.insert(basic_info.id(), idx); + } + ParquetType::GroupType { .. } => { + return Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Leave column in schema should be primitive type but got {:?}", + field_type + ), + )); + } + }; + } + + Ok(column_map) + } +} + +/// A visitor to collect field ids from bound predicates. +struct CollectFieldIdVisitor { + field_ids: Vec<i32>, +} + +impl BoundPredicateVisitor for CollectFieldIdVisitor { + type T = (); + type U = (); + + fn and(&mut self, _predicates: Vec<Self::T>) -> Result<Self::T> { + Ok(()) + } + + fn or(&mut self, _predicates: Vec<Self::T>) -> Result<Self::T> { + Ok(()) + } + + fn not(&mut self, _predicate: Self::T) -> Result<Self::T> { + Ok(()) + } + + fn visit_always_true(&mut self) -> Result<Self::T> { + Ok(()) + } + + fn visit_always_false(&mut self) -> Result<Self::T> { + Ok(()) + } + + fn visit_unary(&mut self, predicate: &UnaryExpression<BoundReference>) -> Result<Self::T> { + self.bound_reference(predicate.term())?; + Ok(()) + } + + fn visit_binary(&mut self, predicate: &BinaryExpression<BoundReference>) -> Result<Self::T> { + self.bound_reference(predicate.term())?; + Ok(()) + } + + fn visit_set(&mut self, predicate: &SetExpression<BoundReference>) -> Result<Self::T> { + self.bound_reference(predicate.term())?; + Ok(()) + } + + fn bound_reference(&mut self, reference: &BoundReference) -> Result<Self::T> { + self.field_ids.push(reference.field().id); + Ok(()) + } +} + +struct PredicateConverter<'a> { + pub columns: &'a Vec<usize>, + pub projection_mask: ProjectionMask, + pub parquet_schema: &'a SchemaDescriptor, + pub column_map: &'a HashMap<i32, usize>, +} + +fn get_arrow_datum(datum: &Datum) -> Box<dyn ArrowDatum> { + match datum.literal() { + PrimitiveLiteral::Boolean(value) => 
Box::new(BooleanArray::new_scalar(*value)), + PrimitiveLiteral::Int(value) => Box::new(Int32Array::new_scalar(*value)), + PrimitiveLiteral::Long(value) => Box::new(Int64Array::new_scalar(*value)), + PrimitiveLiteral::Float(value) => Box::new(Float32Array::new_scalar(value.as_f32())), + PrimitiveLiteral::Double(value) => Box::new(Float64Array::new_scalar(value.as_f64())), + _ => todo!("Unsupported literal type"), Review Comment: We should not panic here, but return an error instead. ########## crates/iceberg/src/arrow.rs: ########## @@ -20,24 +20,38 @@ use async_stream::try_stream; use futures::stream::StreamExt; use parquet::arrow::{ParquetRecordBatchStreamBuilder, ProjectionMask}; +use std::collections::HashMap; use crate::io::FileIO; use crate::scan::{ArrowRecordBatchStream, FileScanTask, FileScanTaskStream}; -use crate::spec::SchemaRef; +use crate::spec::{Datum, PrimitiveLiteral, SchemaRef}; use crate::error::Result; +use crate::expr::{ + BinaryExpression, BoundPredicate, BoundReference, PredicateOperator, SetExpression, + UnaryExpression, +}; use crate::spec::{ ListType, MapType, NestedField, NestedFieldRef, PrimitiveType, Schema, StructType, Type, }; use crate::{Error, ErrorKind}; +use arrow_arith::boolean::{and, is_not_null, is_null, not, or}; +use arrow_array::{ + BooleanArray, Datum as ArrowDatum, Float32Array, Float64Array, Int32Array, Int64Array, +}; +use arrow_ord::cmp::{eq, gt, gt_eq, lt, lt_eq, neq}; use arrow_schema::{DataType, Field, Fields, Schema as ArrowSchema, TimeUnit}; +use bitvec::macros::internal::funty::Fundamental; +use parquet::arrow::arrow_reader::{ArrowPredicate, ArrowPredicateFn, RowFilter}; +use parquet::schema::types::{SchemaDescriptor, Type as ParquetType}; use std::sync::Arc; /// Builder to create ArrowReader pub struct ArrowReaderBuilder { batch_size: Option<usize>, file_io: FileIO, schema: SchemaRef, + predicates: Option<Vec<BoundPredicate>>, Review Comment: The `Vec` is kind of confusing to me: is it a conjunction or a disjunction?
Since we already have `And/Or` as part of `BoundPredicate` variant, how about just `BoundPredicate`? ########## crates/iceberg/src/arrow.rs: ########## @@ -113,6 +143,405 @@ impl ArrowReader { // TODO: full implementation ProjectionMask::all() } + + fn get_row_filter(&self, parquet_schema: &SchemaDescriptor) -> Result<Option<RowFilter>> { + if let Some(predicates) = &self.predicates { + let field_id_map = self.build_field_id_map(parquet_schema)?; + + // Collect Parquet column indices from field ids + let column_indices = predicates + .iter() + .map(|predicate| { + let mut collector = CollectFieldIdVisitor { field_ids: vec![] }; + collector.visit_predicate(predicate).unwrap(); + collector + .field_ids + .iter() + .map(|field_id| { + field_id_map.get(field_id).cloned().ok_or_else(|| { + Error::new(ErrorKind::DataInvalid, "Field id not found in schema") + }) + }) + .collect::<Result<Vec<_>>>() + }) + .collect::<Result<Vec<_>>>()?; + + // Convert BoundPredicates to ArrowPredicates + let mut arrow_predicates = vec![]; + for (predicate, columns) in predicates.iter().zip(column_indices.iter()) { + let mut converter = PredicateConverter { + columns, + projection_mask: ProjectionMask::leaves(parquet_schema, columns.clone()), + parquet_schema, + column_map: &field_id_map, + }; + let arrow_predicate = converter.visit_predicate(predicate)?; + arrow_predicates.push(arrow_predicate); + } + Ok(Some(RowFilter::new(arrow_predicates))) + } else { + Ok(None) + } + } + + /// Build the map of field id to Parquet column index in the schema. + fn build_field_id_map(&self, parquet_schema: &SchemaDescriptor) -> Result<HashMap<i32, usize>> { + let mut column_map = HashMap::new(); + for (idx, field) in parquet_schema.columns().iter().enumerate() { + let field_type = field.self_type(); + match field_type { + ParquetType::PrimitiveType { basic_info, .. 
} => { + if !basic_info.has_id() { + return Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Leave column {:?} in schema doesn't have field id", + field_type + ), + )); + } + column_map.insert(basic_info.id(), idx); + } + ParquetType::GroupType { .. } => { + return Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Leave column in schema should be primitive type but got {:?}", + field_type + ), + )); + } + }; + } + + Ok(column_map) + } +} + +/// A visitor to collect field ids from bound predicates. +struct CollectFieldIdVisitor { + field_ids: Vec<i32>, +} + +impl BoundPredicateVisitor for CollectFieldIdVisitor { + type T = (); + type U = (); + + fn and(&mut self, _predicates: Vec<Self::T>) -> Result<Self::T> { + Ok(()) + } + + fn or(&mut self, _predicates: Vec<Self::T>) -> Result<Self::T> { + Ok(()) + } + + fn not(&mut self, _predicate: Self::T) -> Result<Self::T> { + Ok(()) + } + + fn visit_always_true(&mut self) -> Result<Self::T> { + Ok(()) + } + + fn visit_always_false(&mut self) -> Result<Self::T> { + Ok(()) + } + + fn visit_unary(&mut self, predicate: &UnaryExpression<BoundReference>) -> Result<Self::T> { + self.bound_reference(predicate.term())?; + Ok(()) + } + + fn visit_binary(&mut self, predicate: &BinaryExpression<BoundReference>) -> Result<Self::T> { + self.bound_reference(predicate.term())?; + Ok(()) + } + + fn visit_set(&mut self, predicate: &SetExpression<BoundReference>) -> Result<Self::T> { + self.bound_reference(predicate.term())?; + Ok(()) + } + + fn bound_reference(&mut self, reference: &BoundReference) -> Result<Self::T> { + self.field_ids.push(reference.field().id); + Ok(()) + } +} + +struct PredicateConverter<'a> { + pub columns: &'a Vec<usize>, + pub projection_mask: ProjectionMask, + pub parquet_schema: &'a SchemaDescriptor, + pub column_map: &'a HashMap<i32, usize>, +} + +fn get_arrow_datum(datum: &Datum) -> Box<dyn ArrowDatum> { + match datum.literal() { + PrimitiveLiteral::Boolean(value) => 
Box::new(BooleanArray::new_scalar(*value)), + PrimitiveLiteral::Int(value) => Box::new(Int32Array::new_scalar(*value)), + PrimitiveLiteral::Long(value) => Box::new(Int64Array::new_scalar(*value)), + PrimitiveLiteral::Float(value) => Box::new(Float32Array::new_scalar(value.as_f32())), + PrimitiveLiteral::Double(value) => Box::new(Float64Array::new_scalar(value.as_f64())), + _ => todo!("Unsupported literal type"), + } +} + +impl<'a> BoundPredicateVisitor for PredicateConverter<'a> { + type T = Box<dyn ArrowPredicate>; + type U = usize; + + fn visit_always_true(&mut self) -> Result<Self::T> { + Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + |batch| Ok(BooleanArray::from(vec![true; batch.num_rows()])), + ))) + } + + fn visit_always_false(&mut self) -> Result<Self::T> { + Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + |batch| Ok(BooleanArray::from(vec![false; batch.num_rows()])), + ))) + } + + fn visit_unary(&mut self, predicate: &UnaryExpression<BoundReference>) -> Result<Self::T> { + let term_index = self.bound_reference(predicate.term())?; + + match predicate.op() { + PredicateOperator::IsNull => Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let column = batch.column(term_index); + is_null(column) + }, + ))), + PredicateOperator::NotNull => Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let column = batch.column(term_index); Review Comment: This may be incorrect for nested columns. I think we should either return a `projection_mask` for each leaf column, or implement a general-purpose flatten method for struct arrays.
########## crates/iceberg/src/arrow.rs: ########## @@ -113,6 +143,405 @@ impl ArrowReader { // TODO: full implementation ProjectionMask::all() } + + fn get_row_filter(&self, parquet_schema: &SchemaDescriptor) -> Result<Option<RowFilter>> { + if let Some(predicates) = &self.predicates { + let field_id_map = self.build_field_id_map(parquet_schema)?; + + // Collect Parquet column indices from field ids + let column_indices = predicates + .iter() + .map(|predicate| { + let mut collector = CollectFieldIdVisitor { field_ids: vec![] }; + collector.visit_predicate(predicate).unwrap(); + collector + .field_ids + .iter() + .map(|field_id| { + field_id_map.get(field_id).cloned().ok_or_else(|| { + Error::new(ErrorKind::DataInvalid, "Field id not found in schema") + }) + }) + .collect::<Result<Vec<_>>>() + }) + .collect::<Result<Vec<_>>>()?; + + // Convert BoundPredicates to ArrowPredicates + let mut arrow_predicates = vec![]; + for (predicate, columns) in predicates.iter().zip(column_indices.iter()) { + let mut converter = PredicateConverter { + columns, + projection_mask: ProjectionMask::leaves(parquet_schema, columns.clone()), + parquet_schema, + column_map: &field_id_map, + }; + let arrow_predicate = converter.visit_predicate(predicate)?; + arrow_predicates.push(arrow_predicate); + } + Ok(Some(RowFilter::new(arrow_predicates))) + } else { + Ok(None) + } + } + + /// Build the map of field id to Parquet column index in the schema. + fn build_field_id_map(&self, parquet_schema: &SchemaDescriptor) -> Result<HashMap<i32, usize>> { + let mut column_map = HashMap::new(); + for (idx, field) in parquet_schema.columns().iter().enumerate() { + let field_type = field.self_type(); + match field_type { + ParquetType::PrimitiveType { basic_info, .. 
} => { + if !basic_info.has_id() { + return Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Leave column {:?} in schema doesn't have field id", + field_type + ), + )); + } + column_map.insert(basic_info.id(), idx); + } + ParquetType::GroupType { .. } => { + return Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Leave column in schema should be primitive type but got {:?}", + field_type + ), + )); + } + }; + } + + Ok(column_map) + } +} + +/// A visitor to collect field ids from bound predicates. +struct CollectFieldIdVisitor { + field_ids: Vec<i32>, +} + +impl BoundPredicateVisitor for CollectFieldIdVisitor { + type T = (); + type U = (); + + fn and(&mut self, _predicates: Vec<Self::T>) -> Result<Self::T> { + Ok(()) + } + + fn or(&mut self, _predicates: Vec<Self::T>) -> Result<Self::T> { + Ok(()) + } + + fn not(&mut self, _predicate: Self::T) -> Result<Self::T> { + Ok(()) + } + + fn visit_always_true(&mut self) -> Result<Self::T> { + Ok(()) + } + + fn visit_always_false(&mut self) -> Result<Self::T> { + Ok(()) + } + + fn visit_unary(&mut self, predicate: &UnaryExpression<BoundReference>) -> Result<Self::T> { + self.bound_reference(predicate.term())?; + Ok(()) + } + + fn visit_binary(&mut self, predicate: &BinaryExpression<BoundReference>) -> Result<Self::T> { + self.bound_reference(predicate.term())?; + Ok(()) + } + + fn visit_set(&mut self, predicate: &SetExpression<BoundReference>) -> Result<Self::T> { + self.bound_reference(predicate.term())?; + Ok(()) + } + + fn bound_reference(&mut self, reference: &BoundReference) -> Result<Self::T> { + self.field_ids.push(reference.field().id); + Ok(()) + } +} + +struct PredicateConverter<'a> { + pub columns: &'a Vec<usize>, + pub projection_mask: ProjectionMask, + pub parquet_schema: &'a SchemaDescriptor, + pub column_map: &'a HashMap<i32, usize>, +} + +fn get_arrow_datum(datum: &Datum) -> Box<dyn ArrowDatum> { + match datum.literal() { + PrimitiveLiteral::Boolean(value) => 
Box::new(BooleanArray::new_scalar(*value)), + PrimitiveLiteral::Int(value) => Box::new(Int32Array::new_scalar(*value)), + PrimitiveLiteral::Long(value) => Box::new(Int64Array::new_scalar(*value)), + PrimitiveLiteral::Float(value) => Box::new(Float32Array::new_scalar(value.as_f32())), + PrimitiveLiteral::Double(value) => Box::new(Float64Array::new_scalar(value.as_f64())), + _ => todo!("Unsupported literal type"), + } +} + +impl<'a> BoundPredicateVisitor for PredicateConverter<'a> { + type T = Box<dyn ArrowPredicate>; + type U = usize; + + fn visit_always_true(&mut self) -> Result<Self::T> { + Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + |batch| Ok(BooleanArray::from(vec![true; batch.num_rows()])), + ))) + } + + fn visit_always_false(&mut self) -> Result<Self::T> { + Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + |batch| Ok(BooleanArray::from(vec![false; batch.num_rows()])), + ))) + } + + fn visit_unary(&mut self, predicate: &UnaryExpression<BoundReference>) -> Result<Self::T> { + let term_index = self.bound_reference(predicate.term())?; + + match predicate.op() { + PredicateOperator::IsNull => Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let column = batch.column(term_index); + is_null(column) + }, + ))), + PredicateOperator::NotNull => Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let column = batch.column(term_index); + is_not_null(column) + }, + ))), + PredicateOperator::IsNan => { + todo!("IsNan is not supported yet") + } + PredicateOperator::NotNan => { + todo!("NotNan is not supported yet") + } + op => Err(Error::new( + ErrorKind::DataInvalid, + format!("Unsupported unary operator: {op}"), + )), + } + } + + fn visit_binary(&mut self, predicate: &BinaryExpression<BoundReference>) -> Result<Self::T> { + let term_index = self.bound_reference(predicate.term())?; + let literal = predicate.literal().clone(); + + match predicate.op() { + 
PredicateOperator::LessThan => Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let left = batch.column(term_index); + let literal = get_arrow_datum(&literal); + lt(left, literal.as_ref()) + }, + ))), + PredicateOperator::LessThanOrEq => Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let left = batch.column(term_index); + let literal = get_arrow_datum(&literal); + lt_eq(left, literal.as_ref()) + }, + ))), + PredicateOperator::GreaterThan => Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let left = batch.column(term_index); + let literal = get_arrow_datum(&literal); + gt(left, literal.as_ref()) + }, + ))), + PredicateOperator::GreaterThanOrEq => Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let left = batch.column(term_index); + let literal = get_arrow_datum(&literal); + gt_eq(left, literal.as_ref()) + }, + ))), + PredicateOperator::Eq => Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let left = batch.column(term_index); + let literal = get_arrow_datum(&literal); + eq(left, literal.as_ref()) + }, + ))), + PredicateOperator::NotEq => Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let left = batch.column(term_index); + let literal = get_arrow_datum(&literal); + neq(left, literal.as_ref()) + }, + ))), + op => Err(Error::new( + ErrorKind::DataInvalid, + format!("Unsupported binary operator: {op}"), + )), + } + } + + fn visit_set(&mut self, predicate: &SetExpression<BoundReference>) -> Result<Self::T> { + match predicate.op() { + PredicateOperator::In => { + todo!("In is not supported yet") + } + PredicateOperator::NotIn => { + todo!("NotIn is not supported yet") + } + op => Err(Error::new( + ErrorKind::DataInvalid, + format!("Unsupported set operator: {op}"), + )), + } + } + + fn and(&mut self, mut predicates: Vec<Self::T>) -> 
Result<Self::T> { + Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let left = predicates.get_mut(0).unwrap().evaluate(batch.clone())?; + let right = predicates.get_mut(1).unwrap().evaluate(batch)?; + and(&left, &right) + }, + ))) + } + + fn or(&mut self, mut predicates: Vec<Self::T>) -> Result<Self::T> { + Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let left = predicates.get_mut(0).unwrap().evaluate(batch.clone())?; + let right = predicates.get_mut(1).unwrap().evaluate(batch)?; + or(&left, &right) + }, + ))) + } + + fn not(&mut self, mut predicate: Self::T) -> Result<Self::T> { + Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let evaluated = predicate.evaluate(batch.clone())?; + not(&evaluated) + }, + ))) + } + + fn bound_reference(&mut self, reference: &BoundReference) -> Result<Self::U> { + let column_idx = self.column_map.get(&reference.field().id).ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + format!("Field id {} not found in schema", reference.field().id), + ) + })?; + + let root_col_index = self.parquet_schema.get_column_root_idx(*column_idx); + + // Find the column index in projection mask. + let column_idx = self + .columns + .iter() + .position(|&x| x == root_col_index) + .ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + format!("Column index {} not found in schema", root_col_index), + ) + })?; + + Ok(column_idx) + } +} + +/// A visitor for bound predicates. +pub trait BoundPredicateVisitor { Review Comment: How about moving this to `expr/predicate` module? 
########## crates/iceberg/src/arrow.rs: ########## @@ -113,6 +143,405 @@ impl ArrowReader { // TODO: full implementation ProjectionMask::all() } + + fn get_row_filter(&self, parquet_schema: &SchemaDescriptor) -> Result<Option<RowFilter>> { + if let Some(predicates) = &self.predicates { + let field_id_map = self.build_field_id_map(parquet_schema)?; + + // Collect Parquet column indices from field ids + let column_indices = predicates + .iter() + .map(|predicate| { + let mut collector = CollectFieldIdVisitor { field_ids: vec![] }; + collector.visit_predicate(predicate).unwrap(); + collector + .field_ids + .iter() + .map(|field_id| { + field_id_map.get(field_id).cloned().ok_or_else(|| { + Error::new(ErrorKind::DataInvalid, "Field id not found in schema") + }) + }) + .collect::<Result<Vec<_>>>() + }) + .collect::<Result<Vec<_>>>()?; + + // Convert BoundPredicates to ArrowPredicates + let mut arrow_predicates = vec![]; + for (predicate, columns) in predicates.iter().zip(column_indices.iter()) { + let mut converter = PredicateConverter { + columns, + projection_mask: ProjectionMask::leaves(parquet_schema, columns.clone()), + parquet_schema, + column_map: &field_id_map, + }; + let arrow_predicate = converter.visit_predicate(predicate)?; + arrow_predicates.push(arrow_predicate); + } + Ok(Some(RowFilter::new(arrow_predicates))) + } else { + Ok(None) + } + } + + /// Build the map of field id to Parquet column index in the schema. + fn build_field_id_map(&self, parquet_schema: &SchemaDescriptor) -> Result<HashMap<i32, usize>> { + let mut column_map = HashMap::new(); + for (idx, field) in parquet_schema.columns().iter().enumerate() { + let field_type = field.self_type(); + match field_type { + ParquetType::PrimitiveType { basic_info, .. 
} => { + if !basic_info.has_id() { + return Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Leave column {:?} in schema doesn't have field id", + field_type + ), + )); + } + column_map.insert(basic_info.id(), idx); + } + ParquetType::GroupType { .. } => { + return Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Leave column in schema should be primitive type but got {:?}", + field_type + ), + )); + } + }; + } + + Ok(column_map) + } +} + +/// A visitor to collect field ids from bound predicates. +struct CollectFieldIdVisitor { + field_ids: Vec<i32>, +} + +impl BoundPredicateVisitor for CollectFieldIdVisitor { + type T = (); + type U = (); + + fn and(&mut self, _predicates: Vec<Self::T>) -> Result<Self::T> { + Ok(()) + } + + fn or(&mut self, _predicates: Vec<Self::T>) -> Result<Self::T> { + Ok(()) + } + + fn not(&mut self, _predicate: Self::T) -> Result<Self::T> { + Ok(()) + } + + fn visit_always_true(&mut self) -> Result<Self::T> { + Ok(()) + } + + fn visit_always_false(&mut self) -> Result<Self::T> { + Ok(()) + } + + fn visit_unary(&mut self, predicate: &UnaryExpression<BoundReference>) -> Result<Self::T> { + self.bound_reference(predicate.term())?; + Ok(()) + } + + fn visit_binary(&mut self, predicate: &BinaryExpression<BoundReference>) -> Result<Self::T> { + self.bound_reference(predicate.term())?; + Ok(()) + } + + fn visit_set(&mut self, predicate: &SetExpression<BoundReference>) -> Result<Self::T> { + self.bound_reference(predicate.term())?; + Ok(()) + } + + fn bound_reference(&mut self, reference: &BoundReference) -> Result<Self::T> { + self.field_ids.push(reference.field().id); + Ok(()) + } +} + +struct PredicateConverter<'a> { + pub columns: &'a Vec<usize>, + pub projection_mask: ProjectionMask, + pub parquet_schema: &'a SchemaDescriptor, + pub column_map: &'a HashMap<i32, usize>, +} + +fn get_arrow_datum(datum: &Datum) -> Box<dyn ArrowDatum> { + match datum.literal() { + PrimitiveLiteral::Boolean(value) => 
Box::new(BooleanArray::new_scalar(*value)), + PrimitiveLiteral::Int(value) => Box::new(Int32Array::new_scalar(*value)), + PrimitiveLiteral::Long(value) => Box::new(Int64Array::new_scalar(*value)), + PrimitiveLiteral::Float(value) => Box::new(Float32Array::new_scalar(value.as_f32())), + PrimitiveLiteral::Double(value) => Box::new(Float64Array::new_scalar(value.as_f64())), + _ => todo!("Unsupported literal type"), + } +} + +impl<'a> BoundPredicateVisitor for PredicateConverter<'a> { + type T = Box<dyn ArrowPredicate>; + type U = usize; + + fn visit_always_true(&mut self) -> Result<Self::T> { + Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + |batch| Ok(BooleanArray::from(vec![true; batch.num_rows()])), + ))) + } + + fn visit_always_false(&mut self) -> Result<Self::T> { + Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + |batch| Ok(BooleanArray::from(vec![false; batch.num_rows()])), + ))) + } + + fn visit_unary(&mut self, predicate: &UnaryExpression<BoundReference>) -> Result<Self::T> { + let term_index = self.bound_reference(predicate.term())?; + + match predicate.op() { + PredicateOperator::IsNull => Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let column = batch.column(term_index); + is_null(column) + }, + ))), + PredicateOperator::NotNull => Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let column = batch.column(term_index); + is_not_null(column) + }, + ))), + PredicateOperator::IsNan => { + todo!("IsNan is not supported yet") + } + PredicateOperator::NotNan => { + todo!("NotNan is not supported yet") + } + op => Err(Error::new( + ErrorKind::DataInvalid, + format!("Unsupported unary operator: {op}"), + )), + } + } + + fn visit_binary(&mut self, predicate: &BinaryExpression<BoundReference>) -> Result<Self::T> { + let term_index = self.bound_reference(predicate.term())?; + let literal = predicate.literal().clone(); + + match predicate.op() { + 
PredicateOperator::LessThan => Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let left = batch.column(term_index); + let literal = get_arrow_datum(&literal); + lt(left, literal.as_ref()) + }, + ))), + PredicateOperator::LessThanOrEq => Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let left = batch.column(term_index); + let literal = get_arrow_datum(&literal); + lt_eq(left, literal.as_ref()) + }, + ))), + PredicateOperator::GreaterThan => Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let left = batch.column(term_index); + let literal = get_arrow_datum(&literal); + gt(left, literal.as_ref()) + }, + ))), + PredicateOperator::GreaterThanOrEq => Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let left = batch.column(term_index); + let literal = get_arrow_datum(&literal); + gt_eq(left, literal.as_ref()) + }, + ))), + PredicateOperator::Eq => Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let left = batch.column(term_index); + let literal = get_arrow_datum(&literal); + eq(left, literal.as_ref()) + }, + ))), + PredicateOperator::NotEq => Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let left = batch.column(term_index); + let literal = get_arrow_datum(&literal); + neq(left, literal.as_ref()) + }, + ))), + op => Err(Error::new( + ErrorKind::DataInvalid, + format!("Unsupported binary operator: {op}"), + )), + } + } + + fn visit_set(&mut self, predicate: &SetExpression<BoundReference>) -> Result<Self::T> { + match predicate.op() { + PredicateOperator::In => { + todo!("In is not supported yet") + } + PredicateOperator::NotIn => { + todo!("NotIn is not supported yet") + } + op => Err(Error::new( + ErrorKind::DataInvalid, + format!("Unsupported set operator: {op}"), + )), + } + } + + fn and(&mut self, mut predicates: Vec<Self::T>) -> 
Result<Self::T> { + Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let left = predicates.get_mut(0).unwrap().evaluate(batch.clone())?; + let right = predicates.get_mut(1).unwrap().evaluate(batch)?; + and(&left, &right) + }, + ))) + } + + fn or(&mut self, mut predicates: Vec<Self::T>) -> Result<Self::T> { + Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let left = predicates.get_mut(0).unwrap().evaluate(batch.clone())?; + let right = predicates.get_mut(1).unwrap().evaluate(batch)?; + or(&left, &right) + }, + ))) + } + + fn not(&mut self, mut predicate: Self::T) -> Result<Self::T> { + Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let evaluated = predicate.evaluate(batch.clone())?; + not(&evaluated) + }, + ))) + } + + fn bound_reference(&mut self, reference: &BoundReference) -> Result<Self::U> { + let column_idx = self.column_map.get(&reference.field().id).ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + format!("Field id {} not found in schema", reference.field().id), + ) + })?; + + let root_col_index = self.parquet_schema.get_column_root_idx(*column_idx); + + // Find the column index in projection mask. + let column_idx = self + .columns + .iter() + .position(|&x| x == root_col_index) + .ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + format!("Column index {} not found in schema", root_col_index), + ) + })?; + + Ok(column_idx) + } +} + +/// A visitor for bound predicates. +pub trait BoundPredicateVisitor { + /// Return type of this visitor on bound predicate. + type T; + + /// Return type of this visitor on bound reference. + type U; + + /// Visit a bound predicate. 
+ fn visit_predicate(&mut self, predicate: &BoundPredicate) -> Result<Self::T> { Review Comment: Though it's correct, it's not the typical post-order visitor pattern. How about moving this traversal logic into a standalone function, so that we can follow the convention of the other visitors? ########## crates/iceberg/src/arrow.rs: ########## @@ -113,6 +143,405 @@ impl ArrowReader { // TODO: full implementation ProjectionMask::all() } + + fn get_row_filter(&self, parquet_schema: &SchemaDescriptor) -> Result<Option<RowFilter>> { + if let Some(predicates) = &self.predicates { + let field_id_map = self.build_field_id_map(parquet_schema)?; + + // Collect Parquet column indices from field ids + let column_indices = predicates + .iter() + .map(|predicate| { + let mut collector = CollectFieldIdVisitor { field_ids: vec![] }; + collector.visit_predicate(predicate).unwrap(); + collector + .field_ids + .iter() + .map(|field_id| { + field_id_map.get(field_id).cloned().ok_or_else(|| { + Error::new(ErrorKind::DataInvalid, "Field id not found in schema") + }) + }) + .collect::<Result<Vec<_>>>() + }) + .collect::<Result<Vec<_>>>()?; + + // Convert BoundPredicates to ArrowPredicates + let mut arrow_predicates = vec![]; + for (predicate, columns) in predicates.iter().zip(column_indices.iter()) { + let mut converter = PredicateConverter { + columns, + projection_mask: ProjectionMask::leaves(parquet_schema, columns.clone()), + parquet_schema, + column_map: &field_id_map, + }; + let arrow_predicate = converter.visit_predicate(predicate)?; + arrow_predicates.push(arrow_predicate); + } + Ok(Some(RowFilter::new(arrow_predicates))) + } else { + Ok(None) + } + } + + /// Build the map of field id to Parquet column index in the schema.
+ fn build_field_id_map(&self, parquet_schema: &SchemaDescriptor) -> Result<HashMap<i32, usize>> { + let mut column_map = HashMap::new(); + for (idx, field) in parquet_schema.columns().iter().enumerate() { + let field_type = field.self_type(); + match field_type { + ParquetType::PrimitiveType { basic_info, .. } => { + if !basic_info.has_id() { + return Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Leave column {:?} in schema doesn't have field id", + field_type + ), + )); + } + column_map.insert(basic_info.id(), idx); + } + ParquetType::GroupType { .. } => { + return Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Leave column in schema should be primitive type but got {:?}", + field_type + ), + )); + } + }; + } + + Ok(column_map) + } +} + +/// A visitor to collect field ids from bound predicates. +struct CollectFieldIdVisitor { + field_ids: Vec<i32>, +} + +impl BoundPredicateVisitor for CollectFieldIdVisitor { + type T = (); + type U = (); + + fn and(&mut self, _predicates: Vec<Self::T>) -> Result<Self::T> { + Ok(()) + } + + fn or(&mut self, _predicates: Vec<Self::T>) -> Result<Self::T> { + Ok(()) + } + + fn not(&mut self, _predicate: Self::T) -> Result<Self::T> { + Ok(()) + } + + fn visit_always_true(&mut self) -> Result<Self::T> { + Ok(()) + } + + fn visit_always_false(&mut self) -> Result<Self::T> { + Ok(()) + } + + fn visit_unary(&mut self, predicate: &UnaryExpression<BoundReference>) -> Result<Self::T> { + self.bound_reference(predicate.term())?; + Ok(()) + } + + fn visit_binary(&mut self, predicate: &BinaryExpression<BoundReference>) -> Result<Self::T> { + self.bound_reference(predicate.term())?; + Ok(()) + } + + fn visit_set(&mut self, predicate: &SetExpression<BoundReference>) -> Result<Self::T> { + self.bound_reference(predicate.term())?; + Ok(()) + } + + fn bound_reference(&mut self, reference: &BoundReference) -> Result<Self::T> { + self.field_ids.push(reference.field().id); + Ok(()) + } +} + +struct PredicateConverter<'a> { + pub 
columns: &'a Vec<usize>, + pub projection_mask: ProjectionMask, + pub parquet_schema: &'a SchemaDescriptor, + pub column_map: &'a HashMap<i32, usize>, +} + +fn get_arrow_datum(datum: &Datum) -> Box<dyn ArrowDatum> { + match datum.literal() { + PrimitiveLiteral::Boolean(value) => Box::new(BooleanArray::new_scalar(*value)), + PrimitiveLiteral::Int(value) => Box::new(Int32Array::new_scalar(*value)), + PrimitiveLiteral::Long(value) => Box::new(Int64Array::new_scalar(*value)), + PrimitiveLiteral::Float(value) => Box::new(Float32Array::new_scalar(value.as_f32())), + PrimitiveLiteral::Double(value) => Box::new(Float64Array::new_scalar(value.as_f64())), + _ => todo!("Unsupported literal type"), + } +} + +impl<'a> BoundPredicateVisitor for PredicateConverter<'a> { + type T = Box<dyn ArrowPredicate>; + type U = usize; + + fn visit_always_true(&mut self) -> Result<Self::T> { + Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + |batch| Ok(BooleanArray::from(vec![true; batch.num_rows()])), + ))) + } + + fn visit_always_false(&mut self) -> Result<Self::T> { + Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + |batch| Ok(BooleanArray::from(vec![false; batch.num_rows()])), + ))) + } + + fn visit_unary(&mut self, predicate: &UnaryExpression<BoundReference>) -> Result<Self::T> { + let term_index = self.bound_reference(predicate.term())?; + + match predicate.op() { + PredicateOperator::IsNull => Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let column = batch.column(term_index); + is_null(column) + }, + ))), + PredicateOperator::NotNull => Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let column = batch.column(term_index); + is_not_null(column) + }, + ))), + PredicateOperator::IsNan => { + todo!("IsNan is not supported yet") + } + PredicateOperator::NotNan => { + todo!("NotNan is not supported yet") + } + op => Err(Error::new( + ErrorKind::DataInvalid, + 
format!("Unsupported unary operator: {op}"), Review Comment: We should not return error, but always true. ########## crates/iceberg/src/arrow.rs: ########## @@ -113,6 +143,405 @@ impl ArrowReader { // TODO: full implementation ProjectionMask::all() } + + fn get_row_filter(&self, parquet_schema: &SchemaDescriptor) -> Result<Option<RowFilter>> { + if let Some(predicates) = &self.predicates { + let field_id_map = self.build_field_id_map(parquet_schema)?; + + // Collect Parquet column indices from field ids + let column_indices = predicates + .iter() + .map(|predicate| { + let mut collector = CollectFieldIdVisitor { field_ids: vec![] }; + collector.visit_predicate(predicate).unwrap(); + collector + .field_ids + .iter() + .map(|field_id| { + field_id_map.get(field_id).cloned().ok_or_else(|| { + Error::new(ErrorKind::DataInvalid, "Field id not found in schema") + }) + }) + .collect::<Result<Vec<_>>>() + }) + .collect::<Result<Vec<_>>>()?; + + // Convert BoundPredicates to ArrowPredicates + let mut arrow_predicates = vec![]; + for (predicate, columns) in predicates.iter().zip(column_indices.iter()) { + let mut converter = PredicateConverter { + columns, + projection_mask: ProjectionMask::leaves(parquet_schema, columns.clone()), + parquet_schema, + column_map: &field_id_map, + }; + let arrow_predicate = converter.visit_predicate(predicate)?; + arrow_predicates.push(arrow_predicate); + } + Ok(Some(RowFilter::new(arrow_predicates))) + } else { + Ok(None) + } + } + + /// Build the map of field id to Parquet column index in the schema. + fn build_field_id_map(&self, parquet_schema: &SchemaDescriptor) -> Result<HashMap<i32, usize>> { + let mut column_map = HashMap::new(); + for (idx, field) in parquet_schema.columns().iter().enumerate() { + let field_type = field.self_type(); + match field_type { + ParquetType::PrimitiveType { basic_info, .. 
} => { + if !basic_info.has_id() { + return Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Leave column {:?} in schema doesn't have field id", + field_type + ), + )); + } + column_map.insert(basic_info.id(), idx); + } + ParquetType::GroupType { .. } => { + return Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Leave column in schema should be primitive type but got {:?}", + field_type + ), + )); + } + }; + } + + Ok(column_map) + } +} + +/// A visitor to collect field ids from bound predicates. +struct CollectFieldIdVisitor { Review Comment: Could we add some unit tests for this? ########## crates/iceberg/src/arrow.rs: ########## @@ -113,6 +143,405 @@ impl ArrowReader { // TODO: full implementation ProjectionMask::all() } + + fn get_row_filter(&self, parquet_schema: &SchemaDescriptor) -> Result<Option<RowFilter>> { + if let Some(predicates) = &self.predicates { + let field_id_map = self.build_field_id_map(parquet_schema)?; + + // Collect Parquet column indices from field ids + let column_indices = predicates + .iter() + .map(|predicate| { + let mut collector = CollectFieldIdVisitor { field_ids: vec![] }; + collector.visit_predicate(predicate).unwrap(); + collector + .field_ids + .iter() + .map(|field_id| { + field_id_map.get(field_id).cloned().ok_or_else(|| { + Error::new(ErrorKind::DataInvalid, "Field id not found in schema") + }) + }) + .collect::<Result<Vec<_>>>() + }) + .collect::<Result<Vec<_>>>()?; + + // Convert BoundPredicates to ArrowPredicates + let mut arrow_predicates = vec![]; + for (predicate, columns) in predicates.iter().zip(column_indices.iter()) { + let mut converter = PredicateConverter { + columns, + projection_mask: ProjectionMask::leaves(parquet_schema, columns.clone()), + parquet_schema, + column_map: &field_id_map, + }; + let arrow_predicate = converter.visit_predicate(predicate)?; + arrow_predicates.push(arrow_predicate); + } + Ok(Some(RowFilter::new(arrow_predicates))) + } else { + Ok(None) + } + } + + /// Build the map of field id
to Parquet column index in the schema. + fn build_field_id_map(&self, parquet_schema: &SchemaDescriptor) -> Result<HashMap<i32, usize>> { + let mut column_map = HashMap::new(); + for (idx, field) in parquet_schema.columns().iter().enumerate() { + let field_type = field.self_type(); + match field_type { + ParquetType::PrimitiveType { basic_info, .. } => { + if !basic_info.has_id() { + return Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Leave column {:?} in schema doesn't have field id", + field_type + ), + )); + } + column_map.insert(basic_info.id(), idx); + } + ParquetType::GroupType { .. } => { + return Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Leave column in schema should be primitive type but got {:?}", + field_type + ), + )); + } + }; + } + + Ok(column_map) + } +} + +/// A visitor to collect field ids from bound predicates. +struct CollectFieldIdVisitor { + field_ids: Vec<i32>, +} + +impl BoundPredicateVisitor for CollectFieldIdVisitor { + type T = (); + type U = (); + + fn and(&mut self, _predicates: Vec<Self::T>) -> Result<Self::T> { + Ok(()) + } + + fn or(&mut self, _predicates: Vec<Self::T>) -> Result<Self::T> { + Ok(()) + } + + fn not(&mut self, _predicate: Self::T) -> Result<Self::T> { + Ok(()) + } + + fn visit_always_true(&mut self) -> Result<Self::T> { + Ok(()) + } + + fn visit_always_false(&mut self) -> Result<Self::T> { + Ok(()) + } + + fn visit_unary(&mut self, predicate: &UnaryExpression<BoundReference>) -> Result<Self::T> { + self.bound_reference(predicate.term())?; + Ok(()) + } + + fn visit_binary(&mut self, predicate: &BinaryExpression<BoundReference>) -> Result<Self::T> { + self.bound_reference(predicate.term())?; + Ok(()) + } + + fn visit_set(&mut self, predicate: &SetExpression<BoundReference>) -> Result<Self::T> { + self.bound_reference(predicate.term())?; + Ok(()) + } + + fn bound_reference(&mut self, reference: &BoundReference) -> Result<Self::T> { + self.field_ids.push(reference.field().id); + Ok(()) + } +} + 
+struct PredicateConverter<'a> { + pub columns: &'a Vec<usize>, + pub projection_mask: ProjectionMask, + pub parquet_schema: &'a SchemaDescriptor, + pub column_map: &'a HashMap<i32, usize>, +} + +fn get_arrow_datum(datum: &Datum) -> Box<dyn ArrowDatum> { + match datum.literal() { + PrimitiveLiteral::Boolean(value) => Box::new(BooleanArray::new_scalar(*value)), + PrimitiveLiteral::Int(value) => Box::new(Int32Array::new_scalar(*value)), + PrimitiveLiteral::Long(value) => Box::new(Int64Array::new_scalar(*value)), + PrimitiveLiteral::Float(value) => Box::new(Float32Array::new_scalar(value.as_f32())), + PrimitiveLiteral::Double(value) => Box::new(Float64Array::new_scalar(value.as_f64())), + _ => todo!("Unsupported literal type"), + } +} + +impl<'a> BoundPredicateVisitor for PredicateConverter<'a> { + type T = Box<dyn ArrowPredicate>; + type U = usize; + + fn visit_always_true(&mut self) -> Result<Self::T> { + Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + |batch| Ok(BooleanArray::from(vec![true; batch.num_rows()])), + ))) + } + + fn visit_always_false(&mut self) -> Result<Self::T> { + Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + |batch| Ok(BooleanArray::from(vec![false; batch.num_rows()])), + ))) + } + + fn visit_unary(&mut self, predicate: &UnaryExpression<BoundReference>) -> Result<Self::T> { + let term_index = self.bound_reference(predicate.term())?; + + match predicate.op() { + PredicateOperator::IsNull => Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let column = batch.column(term_index); + is_null(column) + }, + ))), + PredicateOperator::NotNull => Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let column = batch.column(term_index); + is_not_null(column) + }, + ))), + PredicateOperator::IsNan => { + todo!("IsNan is not supported yet") Review Comment: We should not panic here, it should return true. 
-- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@iceberg.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: issues-unsubscr...@iceberg.apache.org For additional commands, e-mail: issues-h...@iceberg.apache.org