This is an automated email from the ASF dual-hosted git repository. mgrigorov pushed a commit to branch serde-driven-schema-aware-deserialization in repository https://gitbox.apache.org/repos/asf/avro-rs.git
commit e76ff35f722c88f9a6dfd754561314ce08038dee Author: Martin Tzvetanov Grigorov <[email protected]> AuthorDate: Mon Jul 21 13:50:42 2025 +0300 Add impl for Schema::Boolean Signed-off-by: Martin Tzvetanov Grigorov <[email protected]> --- avro/src/de.rs | 4 +- avro/src/de_schema.rs | 379 ++++++++++++++++++++++++++++++++++++++++++++++ avro/src/lib.rs | 1 + avro/src/reader.rs | 11 ++ avro/src/ser_schema.rs | 8 +- avro/src/writer.rs | 3 +- avro/tests/avro-rs-226.rs | 11 +- 7 files changed, 406 insertions(+), 11 deletions(-) diff --git a/avro/src/de.rs b/avro/src/de.rs index 8d0c640..c5d8a86 100644 --- a/avro/src/de.rs +++ b/avro/src/de.rs @@ -16,7 +16,7 @@ // under the License. //! Logic for serde-compatible deserialization. -use crate::{Error, bytes::DE_BYTES_BORROWED, types::Value}; +use crate::{AvroResult, Error, bytes::DE_BYTES_BORROWED, types::Value}; use serde::{ Deserialize, de::{self, DeserializeSeed, Deserializer as _, Visitor}, @@ -755,7 +755,7 @@ impl<'de> de::Deserializer<'de> for StringDeserializer { /// /// This conversion can fail if the structure of the `Value` does not match the /// structure expected by `D`. -pub fn from_value<'de, D: Deserialize<'de>>(value: &'de Value) -> Result<D, Error> { +pub fn from_value<'de, D: Deserialize<'de>>(value: &'de Value) -> AvroResult<D> { let de = Deserializer::new(value); D::deserialize(&de) } diff --git a/avro/src/de_schema.rs b/avro/src/de_schema.rs new file mode 100644 index 0000000..63326ba --- /dev/null +++ b/avro/src/de_schema.rs @@ -0,0 +1,379 @@ +use crate::schema::{NamesRef, Namespace}; +use crate::{Error, Schema}; +use serde::de::Visitor; +use std::io::Read; + +pub struct SchemaAwareReadDeserializer<'s, R: Read> { + reader: &'s mut R, + root_schema: &'s Schema, + names: &'s NamesRef<'s>, + enclosing_namespace: Namespace, +} + +impl<'s, R: Read> SchemaAwareReadDeserializer<'s, R> { + pub(crate) fn new( + reader: &'s mut R, + root_schema: &'s Schema, + names: &'s NamesRef<'s>, + enclosing_namespace: Namespace, + ) -> Self { + Self { + reader, + root_schema, + names, + enclosing_namespace, + } + } +} + +impl<'de, R: Read> serde::de::Deserializer<'de> for SchemaAwareReadDeserializer<'de, R> { + type Error = Error; + + fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + // Implement the deserialization logic here + unimplemented!() + } + + fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + let schema = self.root_schema; + let mut this = self; + (&mut this).deserialize_bool_with_schema(visitor, schema) + } + + fn deserialize_i8<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + todo!() + } + + fn deserialize_i16<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + todo!() + } + + fn deserialize_i32<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + todo!() + } + + fn deserialize_i64<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + todo!() + } + + fn deserialize_u8<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + todo!() + } + + fn deserialize_u16<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + todo!() + } + + fn deserialize_u32<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + todo!() + } + + fn deserialize_u64<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + todo!() + } + + fn deserialize_f32<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + todo!() + } + + fn deserialize_f64<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + todo!() + } + + fn deserialize_char<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + todo!() + } + + fn deserialize_str<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + todo!() + } + + fn deserialize_string<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + todo!() + } + + fn deserialize_bytes<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + todo!() + } + + fn deserialize_byte_buf<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + todo!() + } + + fn deserialize_option<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + todo!() + } + + fn deserialize_unit<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + todo!() + } + + fn deserialize_unit_struct<V>( + self, + _name: &'static str, + _visitor: V, + ) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + todo!() + } + + fn deserialize_newtype_struct<V>( + self, + _name: &'static str, + _visitor: V, + ) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + todo!() + } + + fn deserialize_seq<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + todo!() + } + + fn deserialize_tuple<V>(self, _len: usize, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + todo!() + } + + fn deserialize_tuple_struct<V>( + self, + _name: &'static str, + _len: usize, + _visitor: V, + ) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + todo!() + } + + fn deserialize_map<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + todo!() + } + + fn deserialize_struct<V>( + self, + _name: &'static str, + _fields: &'static [&'static str], + _visitor: V, + ) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + todo!() + } + + fn deserialize_enum<V>( + self, + _name: &'static str, + _variants: &'static [&'static str], + _visitor: V, + ) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + todo!() + } + + fn deserialize_identifier<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + todo!() + } + + fn deserialize_ignored_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + todo!() + } +} + +impl<'s, R: Read> SchemaAwareReadDeserializer<'s, R> { + fn deserialize_bool_with_schema<'de, V>( + &mut self, + visitor: V, + schema: &Schema, + ) -> Result<V::Value, Error> + where + V: Visitor<'de>, + { + let create_error = |cause: &str| Error::SerializeValueWithSchema { + // TODO: DeserializeValueWithSchema + value_type: "bool", + value: format!("Cause: {cause}"), + schema: Box::new(schema.clone()), + }; + + match schema { + Schema::Boolean => { + let mut buf = [0; 1]; + self.reader + .read_exact(&mut buf) // Read a single byte + .map_err(|e| create_error(&format!("Failed to read: {e}")))?; + let value = buf[0] != 0; + visitor.visit_bool(value) + } + Schema::Union(union_schema) => { + for (_, variant_schema) in union_schema.schemas.iter().enumerate() { + match variant_schema { + Schema::Boolean => { + return self.deserialize_bool_with_schema(visitor, variant_schema); + } + _ => { /* skip */ } + } + } + Err(create_error(&format!( + "The union schema must have a boolean variant: {schema:?}" + ))) + } + unexpected => Err(create_error(&format!( + "Expected a boolean schema, found: {unexpected:?}" + ))), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::reader::read_avro_datum_ref; + use crate::schema::{Schema, UnionSchema}; + use apache_avro_test_helper::TestResult; + + #[test] + fn avro_rs_226_deserialize_bool_boolean_schema() -> TestResult { + let schema = Schema::Boolean; + + for (byte, expected) in [(0, false), (1, true)] { + let mut reader: &[u8] = &[byte]; + let read: bool = read_avro_datum_ref(&schema, &mut reader)?; + assert_eq!(read, expected); + } + Ok(()) + } + + #[test] + fn avro_rs_226_deserialize_bool_union_boolean_schema() -> TestResult { + let schema = Schema::Union(UnionSchema::new(vec![Schema::Null, Schema::Boolean])?); + + for (byte, expected) in [(0, false), (1, true)] { + let mut reader: &[u8] = &[byte]; + let read: bool = read_avro_datum_ref(&schema, &mut reader)?; + assert_eq!(read, expected); + } + Ok(()) + } + + #[test] + fn avro_rs_226_deserialize_bool_invalid_schema() -> TestResult { + let schema = Schema::Long; // Using a non-boolean schema + + let mut reader: &[u8] = &[0, 1, 2]; + match read_avro_datum_ref::<bool, &[u8]>(&schema, &mut reader) { + Err(Error::SerializeValueWithSchema { + value_type, + value, + schema, + }) => { + assert_eq!(value_type, "bool"); + assert!(value.contains("Cause: Expected a boolean schema")); + assert_eq!(schema.to_string(), schema.to_string()); + } + _ => panic!("Expected an error for invalid schema"), + } + + Ok(()) + } + + #[test] + fn avro_rs_226_deserialize_bool_union_invalid_schema() -> TestResult { + let schema = Schema::Union(UnionSchema::new(vec![Schema::Null, Schema::Long])?); + + let mut reader: &[u8] = &[1, 2, 3]; + match read_avro_datum_ref::<bool, &[u8]>(&schema, &mut reader) { + Err(Error::SerializeValueWithSchema { + value_type, + value, + schema, + }) => { + assert_eq!(value_type, "bool"); + assert!(value.contains("The union schema must have a boolean variant")); + assert_eq!(schema.to_string(), schema.to_string()); + } + _ => panic!("Expected an error for invalid union schema"), + } + + Ok(()) + } +} diff --git a/avro/src/lib.rs b/avro/src/lib.rs index 0f26246..9c9bf01 100644 --- a/avro/src/lib.rs +++ b/avro/src/lib.rs @@ -871,6 +871,7 @@ mod ser_schema; mod util; mod writer; +mod de_schema; pub mod headers; pub mod rabin; pub mod schema; diff --git a/avro/src/reader.rs b/avro/src/reader.rs index fb93c70..24c5b25 100644 --- a/avro/src/reader.rs +++ b/avro/src/reader.rs @@ -16,6 +16,8 @@ // under the License. //! Logic handling reading from Avro format at user level. +use crate::de_schema::SchemaAwareReadDeserializer; +use crate::schema::NamesRef; use crate::{ AvroResult, Codec, Error, decode::{decode, decode_internal}, @@ -596,6 +598,15 @@ pub fn read_marker(bytes: &[u8]) -> [u8; 16] { marker } +pub fn read_avro_datum_ref<'de, D: DeserializeOwned, R: Read>( + schema: &Schema, + reader: &mut R, +) -> AvroResult<D> { + let names: NamesRef = NamesRef::default(); + let deserializer = SchemaAwareReadDeserializer::new(reader, schema, &names, None); + D::deserialize(deserializer) +} + #[cfg(test)] mod tests { use super::*; diff --git a/avro/src/ser_schema.rs b/avro/src/ser_schema.rs index d2fde2a..7d56954 100644 --- a/avro/src/ser_schema.rs +++ b/avro/src/ser_schema.rs @@ -338,14 +338,12 @@ impl<W: Write> ser::SerializeStruct for SchemaAwareWriteSerializeStruct<'_, '_, }; if next_field_matches { - self.serialize_next_field(&value).map_err(|e| { - Error::SerializeRecordFieldWithSchema { + self.serialize_next_field(&value) + .map_err(|e| Error::SerializeRecordFieldWithSchema { field_name: key, record_schema: Box::new(Schema::Record(self.record_schema.clone())), error: Box::new(e), - } - })?; - Ok(()) + }) } else { if self.item_count < self.record_schema.fields.len() { for i in self.item_count..self.record_schema.fields.len() { diff --git a/avro/src/writer.rs b/avro/src/writer.rs index e043425..ee91015 100644 --- a/avro/src/writer.rs +++ b/avro/src/writer.rs @@ -719,8 +719,7 @@ pub fn write_avro_datum_ref<T: Serialize, W: Write>( ) -> AvroResult<usize> { let names: HashMap<Name, &Schema> = HashMap::new(); let mut serializer = SchemaAwareWriteSerializer::new(writer, schema, &names, None); - let bytes_written = data.serialize(&mut serializer)?; - Ok(bytes_written) + data.serialize(&mut serializer) } /// Encode a compatible value (implementing the `ToAvro` trait) into Avro format, also diff --git a/avro/tests/avro-rs-226.rs b/avro/tests/avro-rs-226.rs index dacce84..f4ee037 100644 --- a/avro/tests/avro-rs-226.rs +++ b/avro/tests/avro-rs-226.rs @@ -12,9 +12,16 @@ where writer.append_ser(record)?; let bytes_written = writer.into_inner()?; - let reader = apache_avro::Reader::new(&bytes_written[..])?; + // let mut bytes_written = Cursor::new(bytes_written); + // let value = from_avro_datum(schema, &mut bytes_written, None)?; + // dbg!(&value); + // let deserialized = from_value::<T>(&value)?; + // assert_eq!(deserialized, record2); + + let reader = apache_avro::Reader::with_schema(schema, &bytes_written[..])?; for value in reader { let value = value?; + dbg!(&value); let deserialized = from_value::<T>(&value)?; assert_eq!(deserialized, record2); } @@ -46,7 +53,7 @@ fn avro_rs_226_index_out_of_bounds_with_serde_skip_serializing_skip_middle_field fn avro_rs_226_index_out_of_bounds_with_serde_skip_serializing_skip_first_field() -> TestResult { #[derive(AvroSchema, Clone, Debug, Deserialize, PartialEq, Serialize)] struct T { - #[serde(skip_serializing_if = "Option::is_none")] + // #[serde(skip_serializing_if = "Option::is_none")] x: Option<i8>, y: Option<String>, z: Option<i8>,
