This is an automated email from the ASF dual-hosted git repository.

mgrigorov pushed a commit to branch serde-driven-schema-aware-deserialization
in repository https://gitbox.apache.org/repos/asf/avro-rs.git

commit 0502cc647963e2ed3e96af68c3c6d7eb8d46b65c
Author: Martin Tzvetanov Grigorov <[email protected]>
AuthorDate: Mon Jul 21 13:50:42 2025 +0300

    Add schema-aware deserialization for Schema::Boolean
    
    Signed-off-by: Martin Tzvetanov Grigorov <[email protected]>
---
 avro/src/de.rs            |   8 +-
 avro/src/de_schema.rs     | 379 ++++++++++++++++++++++++++++++++++++++++++++++
 avro/src/lib.rs           |  13 +-
 avro/src/reader.rs        |  23 ++-
 avro/src/ser_schema.rs    |  10 +-
 avro/src/writer.rs        |  13 +-
 avro/tests/avro-rs-226.rs |  11 +-
 7 files changed, 426 insertions(+), 31 deletions(-)

diff --git a/avro/src/de.rs b/avro/src/de.rs
index 8d0c640..c849a8e 100644
--- a/avro/src/de.rs
+++ b/avro/src/de.rs
@@ -16,16 +16,16 @@
 // under the License.
 
 //! Logic for serde-compatible deserialization.
-use crate::{Error, bytes::DE_BYTES_BORROWED, types::Value};
+use crate::{AvroResult, Error, bytes::DE_BYTES_BORROWED, types::Value};
 use serde::{
     Deserialize,
     de::{self, DeserializeSeed, Deserializer as _, Visitor},
     forward_to_deserialize_any,
 };
 use std::{
     collections::{
         HashMap,
         hash_map::{Keys, Values},
     },
     slice::Iter,
 };
@@ -755,7 +755,7 @@ impl<'de> de::Deserializer<'de> for StringDeserializer {
 ///
 /// This conversion can fail if the structure of the `Value` does not match the
 /// structure expected by `D`.
-pub fn from_value<'de, D: Deserialize<'de>>(value: &'de Value) -> Result<D, Error> {
+pub fn from_value<'de, D: Deserialize<'de>>(value: &'de Value) -> AvroResult<D> {
     let de = Deserializer::new(value);
     D::deserialize(&de)
 }
diff --git a/avro/src/de_schema.rs b/avro/src/de_schema.rs
new file mode 100644
index 0000000..63326ba
--- /dev/null
+++ b/avro/src/de_schema.rs
@@ -0,0 +1,379 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Schema-aware deserialization: decodes Avro binary data straight from a [`Read`]
+//! implementor into a `serde::Deserialize` type, guided by an Avro [`Schema`].
+
+use crate::schema::{NamesRef, Namespace};
+use crate::{Error, Schema};
+use serde::de::Visitor;
+use std::io::Read;
+
+/// A `serde` deserializer that decodes Avro data from `reader` using `root_schema`.
+pub struct SchemaAwareReadDeserializer<'s, R: Read> {
+    reader: &'s mut R,
+    root_schema: &'s Schema,
+    // NOTE(review): `names` and `enclosing_namespace` are stored but not read anywhere yet;
+    // presumably they will resolve named-schema references (like the arguments passed to
+    // `SchemaAwareWriteSerializer::new`) — confirm; expect dead-code warnings until then.
+    names: &'s NamesRef<'s>,
+    enclosing_namespace: Namespace,
+}
+
+impl<'s, R: Read> SchemaAwareReadDeserializer<'s, R> {
+    /// Builds a deserializer that pulls bytes from `reader` and interprets them according
+    /// to `root_schema`; `names` and `enclosing_namespace` are stored exactly as given.
+    pub(crate) fn new(
+        reader: &'s mut R,
+        root_schema: &'s Schema,
+        names: &'s NamesRef<'s>,
+        enclosing_namespace: Namespace,
+    ) -> Self {
+        SchemaAwareReadDeserializer {
+            enclosing_namespace,
+            names,
+            root_schema,
+            reader,
+        }
+    }
+}
+
+/// `serde` entry point for schema-aware reads. Only `deserialize_bool` is wired up so
+/// far; every other method is still a placeholder that panics.
+impl<'de, R: Read> serde::de::Deserializer<'de> for SchemaAwareReadDeserializer<'de, R> {
+    type Error = Error;
+
+    fn deserialize_any<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value, Self::Error> {
+        // TODO: implement the deserialization logic here.
+        unimplemented!()
+    }
+
+    fn deserialize_bool<V: Visitor<'de>>(mut self, visitor: V) -> Result<V::Value, Self::Error> {
+        let root = self.root_schema;
+        self.deserialize_bool_with_schema(visitor, root)
+    }
+
+    fn deserialize_i8<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value, Self::Error> {
+        todo!()
+    }
+
+    fn deserialize_i16<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value, Self::Error> {
+        todo!()
+    }
+
+    fn deserialize_i32<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value, Self::Error> {
+        todo!()
+    }
+
+    fn deserialize_i64<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value, Self::Error> {
+        todo!()
+    }
+
+    fn deserialize_u8<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value, Self::Error> {
+        todo!()
+    }
+
+    fn deserialize_u16<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value, Self::Error> {
+        todo!()
+    }
+
+    fn deserialize_u32<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value, Self::Error> {
+        todo!()
+    }
+
+    fn deserialize_u64<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value, Self::Error> {
+        todo!()
+    }
+
+    fn deserialize_f32<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value, Self::Error> {
+        todo!()
+    }
+
+    fn deserialize_f64<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value, Self::Error> {
+        todo!()
+    }
+
+    fn deserialize_char<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value, Self::Error> {
+        todo!()
+    }
+
+    fn deserialize_str<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value, Self::Error> {
+        todo!()
+    }
+
+    fn deserialize_string<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value, Self::Error> {
+        todo!()
+    }
+
+    fn deserialize_bytes<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value, Self::Error> {
+        todo!()
+    }
+
+    fn deserialize_byte_buf<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value, Self::Error> {
+        todo!()
+    }
+
+    fn deserialize_option<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value, Self::Error> {
+        todo!()
+    }
+
+    fn deserialize_unit<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value, Self::Error> {
+        todo!()
+    }
+
+    fn deserialize_unit_struct<V: Visitor<'de>>(
+        self,
+        _name: &'static str,
+        _visitor: V,
+    ) -> Result<V::Value, Self::Error> {
+        todo!()
+    }
+
+    fn deserialize_newtype_struct<V: Visitor<'de>>(
+        self,
+        _name: &'static str,
+        _visitor: V,
+    ) -> Result<V::Value, Self::Error> {
+        todo!()
+    }
+
+    fn deserialize_seq<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value, Self::Error> {
+        todo!()
+    }
+
+    fn deserialize_tuple<V: Visitor<'de>>(
+        self,
+        _len: usize,
+        _visitor: V,
+    ) -> Result<V::Value, Self::Error> {
+        todo!()
+    }
+
+    fn deserialize_tuple_struct<V: Visitor<'de>>(
+        self,
+        _name: &'static str,
+        _len: usize,
+        _visitor: V,
+    ) -> Result<V::Value, Self::Error> {
+        todo!()
+    }
+
+    fn deserialize_map<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value, Self::Error> {
+        todo!()
+    }
+
+    fn deserialize_struct<V: Visitor<'de>>(
+        self,
+        _name: &'static str,
+        _fields: &'static [&'static str],
+        _visitor: V,
+    ) -> Result<V::Value, Self::Error> {
+        todo!()
+    }
+
+    fn deserialize_enum<V: Visitor<'de>>(
+        self,
+        _name: &'static str,
+        _variants: &'static [&'static str],
+        _visitor: V,
+    ) -> Result<V::Value, Self::Error> {
+        todo!()
+    }
+
+    fn deserialize_identifier<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value, Self::Error> {
+        todo!()
+    }
+
+    fn deserialize_ignored_any<V: Visitor<'de>>(
+        self,
+        _visitor: V,
+    ) -> Result<V::Value, Self::Error> {
+        todo!()
+    }
+}
+
+impl<'s, R: Read> SchemaAwareReadDeserializer<'s, R> {
+    /// Decodes one Avro `boolean` from the reader and passes it to `visitor.visit_bool`.
+    ///
+    /// `schema` must be either `Schema::Boolean` or a `Schema::Union` that contains a
+    /// `Schema::Boolean` branch. For a union, the branch index written in the data (a
+    /// zig-zag varint, per the Avro binary encoding) is read first and must select the
+    /// boolean branch.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if `schema` cannot hold a boolean, if the union index read from
+    /// the data is out of range or selects a non-boolean branch, if reading from the
+    /// underlying reader fails, or if the boolean byte is neither `0` nor `1`.
+    fn deserialize_bool_with_schema<'de, V>(
+        &mut self,
+        visitor: V,
+        schema: &Schema,
+    ) -> Result<V::Value, Error>
+    where
+        V: Visitor<'de>,
+    {
+        let create_error = |cause: &str| Error::SerializeValueWithSchema {
+            // TODO: DeserializeValueWithSchema
+            value_type: "bool",
+            value: format!("Cause: {cause}"),
+            schema: Box::new(schema.clone()),
+        };
+
+        match schema {
+            Schema::Boolean => {
+                let mut buf = [0; 1];
+                self.reader
+                    .read_exact(&mut buf) // Read a single byte
+                    .map_err(|e| create_error(&format!("Failed to read: {e}")))?;
+                // Avro writes `false` as 0 and `true` as 1; any other byte is corrupt data,
+                // not `true`.
+                match buf[0] {
+                    0 => visitor.visit_bool(false),
+                    1 => visitor.visit_bool(true),
+                    other => Err(create_error(&format!("Invalid boolean byte: {other}"))),
+                }
+            }
+            Schema::Union(union_schema) => {
+                if !union_schema
+                    .schemas
+                    .iter()
+                    .any(|variant| matches!(variant, Schema::Boolean))
+                {
+                    return Err(create_error(&format!(
+                        "The union schema must have a boolean variant: {schema:?}"
+                    )));
+                }
+                // A union value is encoded as the zero-based index of the branch that was
+                // written, followed by the value encoded with that branch's schema.
+                let index = crate::util::zag_i64(&mut *self.reader)?;
+                let variant_schema = usize::try_from(index)
+                    .ok()
+                    .and_then(|i| union_schema.schemas.get(i))
+                    .ok_or_else(|| {
+                        create_error(&format!("Invalid union branch index: {index}"))
+                    })?;
+                match variant_schema {
+                    Schema::Boolean => self.deserialize_bool_with_schema(visitor, variant_schema),
+                    other => Err(create_error(&format!(
+                        "The union branch at index {index} is not a boolean: {other:?}"
+                    ))),
+                }
+            }
+            unexpected => Err(create_error(&format!(
+                "Expected a boolean schema, found: {unexpected:?}"
+            ))),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::reader::read_avro_datum_ref;
+    use crate::schema::{Schema, UnionSchema};
+    use apache_avro_test_helper::TestResult;
+
+    #[test]
+    fn avro_rs_226_deserialize_bool_boolean_schema() -> TestResult {
+        let schema = Schema::Boolean;
+
+        for (byte, expected) in [(0, false), (1, true)] {
+            let mut reader: &[u8] = &[byte];
+            let read: bool = read_avro_datum_ref(&schema, &mut reader)?;
+            assert_eq!(read, expected);
+        }
+        Ok(())
+    }
+
+    #[test]
+    fn avro_rs_226_deserialize_bool_invalid_byte() -> TestResult {
+        let schema = Schema::Boolean;
+
+        let mut reader: &[u8] = &[2];
+        match read_avro_datum_ref::<bool, &[u8]>(&schema, &mut reader) {
+            Err(Error::SerializeValueWithSchema { value, .. }) => {
+                assert!(value.contains("Invalid boolean byte: 2"));
+            }
+            _ => panic!("Expected an error for an invalid boolean byte"),
+        }
+        Ok(())
+    }
+
+    #[test]
+    fn avro_rs_226_deserialize_bool_union_boolean_schema() -> TestResult {
+        let schema = Schema::Union(UnionSchema::new(vec![Schema::Null, Schema::Boolean])?);
+
+        // Branch index 1 (zig-zag encoded as 2) selects `boolean`; the value byte follows.
+        for (byte, expected) in [(0, false), (1, true)] {
+            let mut reader: &[u8] = &[2, byte];
+            let read: bool = read_avro_datum_ref(&schema, &mut reader)?;
+            assert_eq!(read, expected);
+        }
+        Ok(())
+    }
+
+    #[test]
+    fn avro_rs_226_deserialize_bool_union_non_boolean_branch() -> TestResult {
+        let schema = Schema::Union(UnionSchema::new(vec![Schema::Null, Schema::Boolean])?);
+
+        // Branch index 0 (zig-zag encoded as 0) selects `null`, which is not a bool.
+        let mut reader: &[u8] = &[0];
+        match read_avro_datum_ref::<bool, &[u8]>(&schema, &mut reader) {
+            Err(Error::SerializeValueWithSchema { value, .. }) => {
+                assert!(value.contains("is not a boolean"));
+            }
+            _ => panic!("Expected an error for a non-boolean union branch"),
+        }
+        Ok(())
+    }
+
+    #[test]
+    fn avro_rs_226_deserialize_bool_invalid_schema() -> TestResult {
+        let schema = Schema::Long; // Using a non-boolean schema
+
+        let mut reader: &[u8] = &[0, 1, 2];
+        match read_avro_datum_ref::<bool, &[u8]>(&schema, &mut reader) {
+            Err(Error::SerializeValueWithSchema {
+                value_type,
+                value,
+                schema: error_schema,
+            }) => {
+                assert_eq!(value_type, "bool");
+                assert!(value.contains("Cause: Expected a boolean schema"));
+                assert_eq!(error_schema.to_string(), schema.to_string());
+            }
+            _ => panic!("Expected an error for invalid schema"),
+        }
+
+        Ok(())
+    }
+
+    #[test]
+    fn avro_rs_226_deserialize_bool_union_invalid_schema() -> TestResult {
+        let schema = Schema::Union(UnionSchema::new(vec![Schema::Null, Schema::Long])?);
+
+        let mut reader: &[u8] = &[1, 2, 3];
+        match read_avro_datum_ref::<bool, &[u8]>(&schema, &mut reader) {
+            Err(Error::SerializeValueWithSchema {
+                value_type,
+                value,
+                schema: error_schema,
+            }) => {
+                assert_eq!(value_type, "bool");
+                assert!(value.contains("The union schema must have a boolean variant"));
+                assert_eq!(error_schema.to_string(), schema.to_string());
+            }
+            _ => panic!("Expected an error for invalid union schema"),
+        }
+
+        Ok(())
+    }
+}
diff --git a/avro/src/lib.rs b/avro/src/lib.rs
index 0f26246..f104966 100644
--- a/avro/src/lib.rs
+++ b/avro/src/lib.rs
@@ -871,6 +871,7 @@ mod ser_schema;
 mod util;
 mod writer;
 
+mod de_schema;
 pub mod headers;
 pub mod rabin;
 pub mod schema;
@@ -898,16 +899,16 @@ pub use decimal::Decimal;
 pub use duration::{Days, Duration, Millis, Months};
 pub use error::Error;
 pub use reader::{
     GenericSingleObjectReader, Reader, SpecificSingleObjectReader, from_avro_datum,
-    from_avro_datum_reader_schemata, from_avro_datum_schemata, read_marker,
+    from_avro_datum_reader_schemata, from_avro_datum_schemata, read_avro_datum_ref, read_marker,
 };
 pub use schema::{AvroSchema, Schema};
 pub use ser::to_value;
 pub use util::{max_allocation_bytes, set_serde_human_readable};
 pub use uuid::Uuid;
 pub use writer::{
     GenericSingleObjectWriter, SpecificSingleObjectWriter, Writer, WriterBuilder, to_avro_datum,
     to_avro_datum_schemata, write_avro_datum_ref,
 };
 
 #[cfg(feature = "derive")]
diff --git a/avro/src/reader.rs b/avro/src/reader.rs
index fb93c70..c75fcec 100644
--- a/avro/src/reader.rs
+++ b/avro/src/reader.rs
@@ -16,17 +16,19 @@
 // under the License.
 
 //! Logic handling reading from Avro format at user level.
+use crate::de_schema::SchemaAwareReadDeserializer;
+use crate::schema::NamesRef;
 use crate::{
     AvroResult, Codec, Error,
     decode::{decode, decode_internal},
     from_value,
     headers::{HeaderBuilder, RabinFingerprintHeader},
     schema::{
         AvroSchema, Names, ResolvedOwnedSchema, ResolvedSchema, Schema, resolve_names,
         resolve_names_with_schemata,
     },
     types::Value,
     util,
 };
 use log::warn;
 use serde::de::DeserializeOwned;
@@ -596,6 +598,28 @@ pub fn read_marker(bytes: &[u8]) -> [u8; 16] {
     marker
 }
 
+/// Deserializes a single Avro datum from `reader` into `D`, decoding the bytes according
+/// to `schema` without first building an intermediate [`Value`].
+///
+/// The deserializer is given an empty `NamesRef` and no enclosing namespace.
+///
+/// # Errors
+///
+/// Returns an error if reading from `reader` fails or the data does not match `schema`.
+///
+/// # Panics
+///
+/// Panics if `D` asks for anything other than a `bool`: the schema-aware deserializer
+/// only implements `deserialize_bool` so far.
+pub fn read_avro_datum_ref<D: DeserializeOwned, R: Read>(
+    schema: &Schema,
+    reader: &mut R,
+) -> AvroResult<D> {
+    let names: NamesRef = NamesRef::default();
+    let deserializer = SchemaAwareReadDeserializer::new(reader, schema, &names, None);
+    D::deserialize(deserializer)
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/avro/src/ser_schema.rs b/avro/src/ser_schema.rs
index d2fde2a..a3fe197 100644
--- a/avro/src/ser_schema.rs
+++ b/avro/src/ser_schema.rs
@@ -338,14 +338,12 @@ impl<W: Write> ser::SerializeStruct for SchemaAwareWriteSerializeStruct<'_, '_,
         };
 
         if next_field_matches {
-            self.serialize_next_field(&value).map_err(|e| {
-                Error::SerializeRecordFieldWithSchema {
+            self.serialize_next_field(&value)
+                .map_err(|e| Error::SerializeRecordFieldWithSchema {
                     field_name: key,
                     record_schema: Box::new(Schema::Record(self.record_schema.clone())),
                     error: Box::new(e),
-                }
-            })?;
-            Ok(())
+                })
         } else {
             if self.item_count < self.record_schema.fields.len() {
                 for i in self.item_count..self.record_schema.fields.len() {
diff --git a/avro/src/writer.rs b/avro/src/writer.rs
index e043425..7b256c1 100644
--- a/avro/src/writer.rs
+++ b/avro/src/writer.rs
@@ -719,8 +719,7 @@ pub fn write_avro_datum_ref<T: Serialize, W: Write>(
 ) -> AvroResult<usize> {
     let names: HashMap<Name, &Schema> = HashMap::new();
     let mut serializer = SchemaAwareWriteSerializer::new(writer, schema, &names, None);
-    let bytes_written = data.serialize(&mut serializer)?;
-    Ok(bytes_written)
+    data.serialize(&mut serializer)
 }
 
 /// Encode a compatible value (implementing the `ToAvro` trait) into Avro format, also
diff --git a/avro/tests/avro-rs-226.rs b/avro/tests/avro-rs-226.rs
index dacce84..f4ee037 100644
--- a/avro/tests/avro-rs-226.rs
+++ b/avro/tests/avro-rs-226.rs
@@ -12,9 +12,9 @@ where
     writer.append_ser(record)?;
     let bytes_written = writer.into_inner()?;
 
-    let reader = apache_avro::Reader::new(&bytes_written[..])?;
+    let reader = apache_avro::Reader::with_schema(schema, &bytes_written[..])?;
     for value in reader {
         let value = value?;
         let deserialized = from_value::<T>(&value)?;
         assert_eq!(deserialized, record2);
     }

Reply via email to