viirya commented on code in PR #295:
URL: https://github.com/apache/iceberg-rust/pull/295#discussion_r1583629653
##########
crates/iceberg/src/arrow/reader.rs:
##########
@@ -186,4 +221,637 @@ impl ArrowReader {
Ok(ProjectionMask::leaves(parquet_schema, indices))
}
}
+
+ fn get_row_filter(
+ &self,
+ parquet_schema: &SchemaDescriptor,
+ collector: &CollectFieldIdVisitor,
+ ) -> Result<Option<RowFilter>> {
+ if let Some(predicates) = &self.predicates {
+ let field_id_map = build_field_id_map(parquet_schema)?;
+
+ let column_indices = collector
+ .field_ids
+ .iter()
+ .map(|field_id| {
+ field_id_map.get(field_id).cloned().ok_or_else(|| {
+ Error::new(ErrorKind::DataInvalid, "Field id not found
in schema")
+ })
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ // Convert BoundPredicates to ArrowPredicates
+ let mut converter = PredicateConverter {
+ columns: &column_indices,
+ projection_mask: ProjectionMask::leaves(parquet_schema,
column_indices.clone()),
+ parquet_schema,
+ column_map: &field_id_map,
+ };
+ let arrow_predicate = visit(&mut converter, predicates)?;
+ Ok(Some(RowFilter::new(vec![arrow_predicate])))
+ } else {
+ Ok(None)
+ }
+ }
+}
+
+/// Build the map of field id to Parquet column index in the schema.
+fn build_field_id_map(parquet_schema: &SchemaDescriptor) ->
Result<HashMap<i32, usize>> {
+ let mut column_map = HashMap::new();
+ for (idx, field) in parquet_schema.columns().iter().enumerate() {
+ let field_type = field.self_type();
+ match field_type {
+ ParquetType::PrimitiveType { basic_info, .. } => {
+ if !basic_info.has_id() {
+ return Err(Error::new(
+ ErrorKind::DataInvalid,
+ format!(
+ "Leave column {:?} in schema doesn't have field
id",
+ field_type
+ ),
+ ));
+ }
+ column_map.insert(basic_info.id(), idx);
+ }
+ ParquetType::GroupType { .. } => {
+ return Err(Error::new(
+ ErrorKind::DataInvalid,
+ format!(
+ "Leave column in schema should be primitive type but
got {:?}",
+ field_type
+ ),
+ ));
+ }
+ };
+ }
+
+ Ok(column_map)
+}
+
+/// A visitor to collect field ids from bound predicates.
+struct CollectFieldIdVisitor {
+ field_ids: Vec<i32>,
+}
+
+impl BoundPredicateVisitor for CollectFieldIdVisitor {
+ type T = ();
+
+ fn always_true(&mut self) -> Result<Self::T> {
+ Ok(())
+ }
+
+ fn always_false(&mut self) -> Result<Self::T> {
+ Ok(())
+ }
+
+ fn and(&mut self, _lhs: Self::T, _rhs: Self::T) -> Result<Self::T> {
+ Ok(())
+ }
+
+ fn or(&mut self, _lhs: Self::T, _rhs: Self::T) -> Result<Self::T> {
+ Ok(())
+ }
+
+ fn not(&mut self, _inner: Self::T) -> Result<Self::T> {
+ Ok(())
+ }
+
+ fn is_null(
+ &mut self,
+ reference: &BoundReference,
+ _predicate: &BoundPredicate,
+ ) -> Result<Self::T> {
+ self.field_ids.push(reference.field().id);
+ Ok(())
+ }
+
+ fn not_null(
+ &mut self,
+ reference: &BoundReference,
+ _predicate: &BoundPredicate,
+ ) -> Result<Self::T> {
+ self.field_ids.push(reference.field().id);
+ Ok(())
+ }
+
+ fn is_nan(
+ &mut self,
+ reference: &BoundReference,
+ _predicate: &BoundPredicate,
+ ) -> Result<Self::T> {
+ self.field_ids.push(reference.field().id);
+ Ok(())
+ }
+
+ fn not_nan(
+ &mut self,
+ reference: &BoundReference,
+ _predicate: &BoundPredicate,
+ ) -> Result<Self::T> {
+ self.field_ids.push(reference.field().id);
+ Ok(())
+ }
+
+ fn less_than(
+ &mut self,
+ reference: &BoundReference,
+ _literal: &Datum,
+ _predicate: &BoundPredicate,
+ ) -> Result<Self::T> {
+ self.field_ids.push(reference.field().id);
+ Ok(())
+ }
+
+ fn less_than_or_eq(
+ &mut self,
+ reference: &BoundReference,
+ _literal: &Datum,
+ _predicate: &BoundPredicate,
+ ) -> Result<Self::T> {
+ self.field_ids.push(reference.field().id);
+ Ok(())
+ }
+
+ fn greater_than(
+ &mut self,
+ reference: &BoundReference,
+ _literal: &Datum,
+ _predicate: &BoundPredicate,
+ ) -> Result<Self::T> {
+ self.field_ids.push(reference.field().id);
+ Ok(())
+ }
+
+ fn greater_than_or_eq(
+ &mut self,
+ reference: &BoundReference,
+ _literal: &Datum,
+ _predicate: &BoundPredicate,
+ ) -> Result<Self::T> {
+ self.field_ids.push(reference.field().id);
+ Ok(())
+ }
+
+ fn eq(
+ &mut self,
+ reference: &BoundReference,
+ _literal: &Datum,
+ _predicate: &BoundPredicate,
+ ) -> Result<Self::T> {
+ self.field_ids.push(reference.field().id);
+ Ok(())
+ }
+
+ fn not_eq(
+ &mut self,
+ reference: &BoundReference,
+ _literal: &Datum,
+ _predicate: &BoundPredicate,
+ ) -> Result<Self::T> {
+ self.field_ids.push(reference.field().id);
+ Ok(())
+ }
+
+ fn starts_with(
+ &mut self,
+ reference: &BoundReference,
+ _literal: &Datum,
+ _predicate: &BoundPredicate,
+ ) -> Result<Self::T> {
+ self.field_ids.push(reference.field().id);
+ Ok(())
+ }
+
+ fn not_starts_with(
+ &mut self,
+ reference: &BoundReference,
+ _literal: &Datum,
+ _predicate: &BoundPredicate,
+ ) -> Result<Self::T> {
+ self.field_ids.push(reference.field().id);
+ Ok(())
+ }
+
+ fn r#in(
+ &mut self,
+ reference: &BoundReference,
+ _literals: &FnvHashSet<Datum>,
+ _predicate: &BoundPredicate,
+ ) -> Result<Self::T> {
+ self.field_ids.push(reference.field().id);
+ Ok(())
+ }
+
+ fn not_in(
+ &mut self,
+ reference: &BoundReference,
+ _literals: &FnvHashSet<Datum>,
+ _predicate: &BoundPredicate,
+ ) -> Result<Self::T> {
+ self.field_ids.push(reference.field().id);
+ Ok(())
+ }
+}
+
+/// A visitor to convert Iceberg bound predicates to Arrow predicates.
+struct PredicateConverter<'a> {
+ /// The leaf column indices used in the predicates.
+ pub columns: &'a Vec<usize>,
+ /// The projection mask for the Arrow predicates.
+ pub projection_mask: ProjectionMask,
+ /// The Parquet schema descriptor.
+ pub parquet_schema: &'a SchemaDescriptor,
+ /// The map between field id and leaf column index in Parquet schema.
+ pub column_map: &'a HashMap<i32, usize>,
+}
+
+impl PredicateConverter<'_> {
+ /// When visiting a bound reference, we return the projection mask for the
leaf column
+ /// which is used to project the column in the record batch.
+ fn bound_reference(&mut self, reference: &BoundReference) ->
Result<ProjectionMask> {
+ // The leaf column's index in Parquet schema.
+ let column_idx =
self.column_map.get(&reference.field().id).ok_or_else(|| {
+ Error::new(
+ ErrorKind::DataInvalid,
+ format!("Field id {} not found in schema",
reference.field().id),
+ )
+ })?;
+
+ // Find the column index in projection mask.
+ let column_idx = self
+ .columns
+ .iter()
+ .position(|&x| x == *column_idx)
+ .ok_or_else(|| {
+ Error::new(
+ ErrorKind::DataInvalid,
+ format!("Column index {} not found in schema",
*column_idx),
+ )
+ })?;
+
+ Ok(ProjectionMask::leaves(
+ self.parquet_schema,
+ vec![self.columns[column_idx]],
+ ))
+ }
+}
+
+fn get_arrow_datum(datum: &Datum) -> Result<Box<dyn ArrowDatum + Send>> {
+ match datum.literal() {
+ PrimitiveLiteral::Boolean(value) =>
Ok(Box::new(BooleanArray::new_scalar(*value))),
+ PrimitiveLiteral::Int(value) =>
Ok(Box::new(Int32Array::new_scalar(*value))),
+ PrimitiveLiteral::Long(value) =>
Ok(Box::new(Int64Array::new_scalar(*value))),
+ PrimitiveLiteral::Float(value) =>
Ok(Box::new(Float32Array::new_scalar(value.as_f32()))),
+ PrimitiveLiteral::Double(value) =>
Ok(Box::new(Float64Array::new_scalar(value.as_f64()))),
+ l => Err(Error::new(
+ ErrorKind::DataInvalid,
+ format!("Unsupported literal type: {:?}", l),
+ )),
+ }
+}
+
+/// Recursively get the leaf column from the record batch. Assume that the
nested columns in
+/// struct is projected to a single column.
+fn get_leaf_column(column: &ArrayRef) -> std::result::Result<ArrayRef,
ArrowError> {
+ match column.data_type() {
+ DataType::Struct(fields) => {
+ if fields.len() != 1 {
+ return Err(ArrowError::SchemaError(
+ "Struct column should have only one field after projection"
+ .parse()
+ .unwrap(),
+ ));
+ }
+ let struct_array =
column.as_any().downcast_ref::<StructArray>().unwrap();
+ get_leaf_column(struct_array.column(0))
+ }
+ _ => Ok(column.clone()),
+ }
+}
+
+impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
+ type T = Box<dyn ArrowPredicate>;
+
+ fn always_true(&mut self) -> Result<Self::T> {
+ Ok(Box::new(ArrowPredicateFn::new(
+ self.projection_mask.clone(),
+ |batch| Ok(BooleanArray::from(vec![true; batch.num_rows()])),
+ )))
+ }
+
+ fn always_false(&mut self) -> Result<Self::T> {
+ Ok(Box::new(ArrowPredicateFn::new(
+ self.projection_mask.clone(),
+ |batch| Ok(BooleanArray::from(vec![false; batch.num_rows()])),
+ )))
+ }
+
+ fn and(&mut self, mut lhs: Self::T, mut rhs: Self::T) -> Result<Self::T> {
+ Ok(Box::new(ArrowPredicateFn::new(
+ self.projection_mask.clone(),
+ move |batch| {
+ let left = lhs.evaluate(batch.clone())?;
+ let right = rhs.evaluate(batch)?;
+ and(&left, &right)
+ },
+ )))
+ }
+
+ fn or(&mut self, mut lhs: Self::T, mut rhs: Self::T) -> Result<Self::T> {
+ Ok(Box::new(ArrowPredicateFn::new(
+ self.projection_mask.clone(),
+ move |batch| {
+ let left = lhs.evaluate(batch.clone())?;
+ let right = rhs.evaluate(batch)?;
+ or(&left, &right)
+ },
+ )))
+ }
+
+ fn not(&mut self, mut inner: Self::T) -> Result<Self::T> {
+ Ok(Box::new(ArrowPredicateFn::new(
+ self.projection_mask.clone(),
+ move |batch| {
+ let pred_ret = inner.evaluate(batch)?;
+ not(&pred_ret)
+ },
+ )))
+ }
+
+ fn is_null(
+ &mut self,
+ reference: &BoundReference,
+ _predicate: &BoundPredicate,
+ ) -> Result<Self::T> {
+ let projected_mask = self.bound_reference(reference)?;
Review Comment:
Based on my search and the kind reply on the issue, I think there is currently
no way to do nested projection on a RecordBatch.
Implementing that feature in arrow-rs first might block this PR, so I'd prefer to
support only top-level columns in this PR.
WDYT, @liurenjie1024?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]