This is an automated email from the ASF dual-hosted git repository.

mgrigorov pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/avro-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new 2019363  feat: Implement support for `#[serde(flatten)]` (#359)
2019363 is described below

commit 2019363e573b95ab3c6bdda8cafa982ec052f7e9
Author: Kriskras99 <[email protected]>
AuthorDate: Thu Dec 11 21:30:34 2025 +0100

    feat: Implement support for `#[serde(flatten)]` (#359)
    
    * chore: Move all Serde related modules to the `serde` module
    
    * feat: Implement support for `#[serde(flatten)]`
    
    This is done by adding an `#[avro(flatten)]` attribute so that the
    schema (for that field) is also flattened, and by adding support
    in `SchemaAwareWriteSerializer` for serializing a struct via
    `SerializeMap` instead of `SerializeStruct`.
    
    `flatten` does not work with `to_value`, as `to_value` does not have
    access to the schema.
    
    * fix: Handle duplicate fields when flatten is used
    
    ---------
    
    Co-authored-by: default <[email protected]>
---
 avro/src/error.rs                  |   6 +
 avro/src/lib.rs                    |   7 +-
 avro/src/{ => serde}/de.rs         |   0
 avro/src/serde/mod.rs              |   4 +
 avro/src/{ => serde}/ser.rs        |   7 +-
 avro/src/{ => serde}/ser_schema.rs | 136 ++++++++++++++---
 avro/src/serde/util.rs             | 300 +++++++++++++++++++++++++++++++++++++
 avro/src/types.rs                  |   6 +-
 avro/src/writer.rs                 |   2 +-
 avro_derive/src/lib.rs             |  68 ++++++---
 avro_derive/tests/derive.rs        | 164 ++++++++++++++++++++
 11 files changed, 650 insertions(+), 50 deletions(-)

diff --git a/avro/src/error.rs b/avro/src/error.rs
index a3e2cf0..bdb2055 100644
--- a/avro/src/error.rs
+++ b/avro/src/error.rs
@@ -579,6 +579,12 @@ pub enum Details {
 
     #[error("Cannot convert a slice to Uuid: {0}")]
     UuidFromSlice(#[source] uuid::Error),
+
+    #[error("Expected String for Map key when serializing a flattened struct")]
+    MapFieldExpectedString,
+
+    #[error("No key for value when serializing a map")]
+    MapNoKey,
 }
 
 #[derive(thiserror::Error, PartialEq)]
diff --git a/avro/src/lib.rs b/avro/src/lib.rs
index 853722d..f75c5a3 100644
--- a/avro/src/lib.rs
+++ b/avro/src/lib.rs
@@ -945,14 +945,12 @@
 mod bigdecimal;
 mod bytes;
 mod codec;
-mod de;
 mod decimal;
 mod decode;
 mod duration;
 mod encode;
 mod reader;
-mod ser;
-mod ser_schema;
+mod serde;
 mod writer;
 
 pub mod error;
@@ -979,7 +977,6 @@ pub use codec::xz::XzSettings;
 #[cfg(feature = "zstandard")]
 pub use codec::zstandard::ZstandardSettings;
 pub use codec::{Codec, DeflateSettings};
-pub use de::from_value;
 pub use decimal::Decimal;
 pub use duration::{Days, Duration, Millis, Months};
 pub use error::Error;
@@ -988,7 +985,7 @@ pub use reader::{
     from_avro_datum_reader_schemata, from_avro_datum_schemata, read_marker,
 };
 pub use schema::{AvroSchema, Schema};
-pub use ser::to_value;
+pub use serde::{de::from_value, ser::to_value};
 pub use uuid::Uuid;
 pub use writer::{
     GenericSingleObjectWriter, SpecificSingleObjectWriter, Writer, 
WriterBuilder, to_avro_datum,
diff --git a/avro/src/de.rs b/avro/src/serde/de.rs
similarity index 100%
rename from avro/src/de.rs
rename to avro/src/serde/de.rs
diff --git a/avro/src/serde/mod.rs b/avro/src/serde/mod.rs
new file mode 100644
index 0000000..509d2e5
--- /dev/null
+++ b/avro/src/serde/mod.rs
@@ -0,0 +1,4 @@
+pub mod de;
+pub mod ser;
+pub mod ser_schema;
+mod util;
diff --git a/avro/src/ser.rs b/avro/src/serde/ser.rs
similarity index 99%
rename from avro/src/ser.rs
rename to avro/src/serde/ser.rs
index 1bc9075..d78f501 100644
--- a/avro/src/ser.rs
+++ b/avro/src/serde/ser.rs
@@ -479,7 +479,12 @@ impl ser::SerializeStructVariant for 
StructVariantSerializer<'_> {
 /// Interpret a serializeable instance as a `Value`.
 ///
 /// This conversion can fail if the value is not valid as per the Avro 
specification.
-/// e.g: HashMap with non-string keys
+/// e.g: `HashMap` with non-string keys.
+///
+/// This function does not work if `S` has any fields (recursively) that have 
the `#[serde(flatten)]`
+/// attribute. Please use [`Writer::append_ser`] if that's the case.
+///
+/// [`Writer::append_ser`]: crate::Writer::append_ser
 pub fn to_value<S: Serialize>(value: S) -> Result<Value, Error> {
     let mut serializer = Serializer::default();
     value.serialize(&mut serializer)
diff --git a/avro/src/ser_schema.rs b/avro/src/serde/ser_schema.rs
similarity index 96%
rename from avro/src/ser_schema.rs
rename to avro/src/serde/ser_schema.rs
index f9ee2fc..02a65bc 100644
--- a/avro/src/ser_schema.rs
+++ b/avro/src/serde/ser_schema.rs
@@ -23,9 +23,10 @@ use crate::{
     encode::{encode_int, encode_long},
     error::{Details, Error},
     schema::{Name, NamesRef, Namespace, RecordField, RecordSchema, Schema},
+    serde::util::StringSerializer,
 };
 use bigdecimal::BigDecimal;
-use serde::ser;
+use serde::{Serialize, ser};
 use std::{borrow::Cow, cmp::Ordering, collections::HashMap, io::Write, 
str::FromStr};
 
 const COLLECTION_SERIALIZER_ITEM_LIMIT: usize = 1024;
@@ -251,6 +252,8 @@ pub struct SchemaAwareWriteSerializeStruct<'a, 's, W: 
Write> {
     record_schema: &'s RecordSchema,
     /// Fields we received in the wrong order
     field_cache: HashMap<usize, Vec<u8>>,
+    /// The current field name when serializing from a map (for `flatten` 
support).
+    map_field_name: Option<String>,
     field_position: usize,
     bytes_written: usize,
 }
@@ -264,6 +267,7 @@ impl<'a, 's, W: Write> SchemaAwareWriteSerializeStruct<'a, 
's, W> {
             ser,
             record_schema,
             field_cache: HashMap::new(),
+            map_field_name: None,
             field_position: 0,
             bytes_written: 0,
         }
@@ -352,6 +356,11 @@ impl<'a, 's, W: Write> SchemaAwareWriteSerializeStruct<'a, 
's, W> {
             "There should be no more unwritten fields at this point: {:?}",
             self.field_cache
         );
+        debug_assert!(
+            self.map_field_name.is_none(),
+            "There should be no field name at this point: field {:?}",
+            self.map_field_name
+        );
         Ok(self.bytes_written)
     }
 }
@@ -371,17 +380,14 @@ impl<W: Write> ser::SerializeStruct for 
SchemaAwareWriteSerializeStruct<'_, '_,
             .and_then(|idx| self.record_schema.fields.get(*idx));
 
         match record_field {
-            Some(field) => {
-                // self.item_count += 1;
-                self.serialize_next_field(field, value).map_err(|e| {
-                    Details::SerializeRecordFieldWithSchema {
-                        field_name: key.to_string(),
-                        record_schema: 
Schema::Record(self.record_schema.clone()),
-                        error: Box::new(e),
-                    }
-                    .into()
-                })
-            }
+            Some(field) => self.serialize_next_field(field, value).map_err(|e| 
{
+                Details::SerializeRecordFieldWithSchema {
+                    field_name: key.to_string(),
+                    record_schema: Schema::Record(self.record_schema.clone()),
+                    error: Box::new(e),
+                }
+                .into()
+            }),
             None => Err(Details::FieldName(String::from(key)).into()),
         }
     }
@@ -420,6 +426,53 @@ impl<W: Write> ser::SerializeStruct for 
SchemaAwareWriteSerializeStruct<'_, '_,
     }
 }
 
+/// Serializes a record through `SerializeMap`, which serde uses instead of `SerializeStruct` 
for structs that contain a `#[serde(flatten)]` field.
+impl<W: Write> ser::SerializeMap for SchemaAwareWriteSerializeStruct<'_, '_, 
W> {
+    type Ok = usize;
+    type Error = Error;
+
+    fn serialize_key<T>(&mut self, key: &T) -> Result<(), Self::Error>
+    where
+        T: ?Sized + Serialize,
+    {
+        let name = key.serialize(StringSerializer)?; // keys must be strings (record field names)
+        let old = self.map_field_name.replace(name); // held until the matching serialize_value
+        debug_assert!(
+            old.is_none(),
+            "Expected a value instead of a key: old key: {old:?}, new key: 
{:?}",
+            self.map_field_name
+        );
+        Ok(())
+    }
+
+    fn serialize_value<T>(&mut self, value: &T) -> Result<(), Self::Error>
+    where
+        T: ?Sized + Serialize,
+    {
+        let key = self.map_field_name.take().ok_or(Details::MapNoKey)?; // value with no key
+        let record_field = self
+            .record_schema
+            .lookup
+            .get(&key)
+            .and_then(|idx| self.record_schema.fields.get(*idx));
+        match record_field {
+            Some(field) => self.serialize_next_field(field, value).map_err(|e| 
{
+                Details::SerializeRecordFieldWithSchema {
+                    field_name: key.to_string(),
+                    record_schema: Schema::Record(self.record_schema.clone()),
+                    error: Box::new(e),
+                }
+                .into()
+            }),
+            None => Err(Details::FieldName(key).into()), // key is not a field of this record
+        }
+    }
+
+    fn end(self) -> Result<Self::Ok, Self::Error> {
+        self.end() // NOTE(review): expected to resolve to the inherent `end`, not recurse here
+    }
+}
+
 impl<W: Write> ser::SerializeStructVariant for 
SchemaAwareWriteSerializeStruct<'_, '_, W> {
     type Ok = usize;
     type Error = Error;
@@ -436,6 +489,46 @@ impl<W: Write> ser::SerializeStructVariant for 
SchemaAwareWriteSerializeStruct<'
     }
 }
 
+/// `SerializeMap` implementation that dispatches to either a record or an Avro map serializer.
+///
+/// This exists because when `#[serde(flatten)]` is used, struct fields are 
serialized as a map.
+pub enum SchemaAwareWriteSerializeMapOrStruct<'a, 's, W: Write> {
+    Struct(SchemaAwareWriteSerializeStruct<'a, 's, W>), // chosen for a record schema
+    Map(SchemaAwareWriteSerializeMap<'a, 's, W>), // chosen for an Avro map schema
+}
+
+impl<W: Write> ser::SerializeMap for SchemaAwareWriteSerializeMapOrStruct<'_, 
'_, W> {
+    type Ok = usize;
+    type Error = Error;
+
+    fn serialize_key<T>(&mut self, key: &T) -> Result<(), Self::Error>
+    where
+        T: ?Sized + Serialize,
+    {
+        match self {
+            Self::Struct(s) => s.serialize_key(key),
+            Self::Map(s) => s.serialize_key(key),
+        }
+    }
+
+    fn serialize_value<T>(&mut self, value: &T) -> Result<(), Self::Error>
+    where
+        T: ?Sized + Serialize,
+    {
+        match self {
+            Self::Struct(s) => s.serialize_value(value),
+            Self::Map(s) => s.serialize_value(value),
+        }
+    }
+
+    fn end(self) -> Result<Self::Ok, Self::Error> {
+        match self {
+            Self::Struct(s) => s.end(),
+            Self::Map(s) => s.end(),
+        }
+    }
+}
+
 /// The tuple struct serializer for [`SchemaAwareWriteSerializer`].
 /// [`SchemaAwareWriteSerializeTupleStruct`] can serialize to an Avro array, 
record, or big-decimal.
 /// When serializing to a record, fields must be provided in the correct 
order, since no names are provided.
@@ -1499,7 +1592,7 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> {
         &'a mut self,
         len: Option<usize>,
         schema: &'s Schema,
-    ) -> Result<SchemaAwareWriteSerializeMap<'a, 's, W>, Error> {
+    ) -> Result<SchemaAwareWriteSerializeMapOrStruct<'a, 's, W>, Error> {
         let create_error = |cause: String| {
             let len_str = len
                 .map(|l| format!("{l}"))
@@ -1513,15 +1606,17 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> {
         };
 
         match schema {
-            Schema::Map(map_schema) => Ok(SchemaAwareWriteSerializeMap::new(
-                self,
-                map_schema.types.as_ref(),
-                len,
+            Schema::Map(map_schema) => 
Ok(SchemaAwareWriteSerializeMapOrStruct::Map(
+                SchemaAwareWriteSerializeMap::new(self, 
map_schema.types.as_ref(), len),
             )),
+            Schema::Ref { name: ref_name } => {
+                let ref_schema = self.get_ref_schema(ref_name)?;
+                self.serialize_map_with_schema(len, ref_schema)
+            }
             Schema::Union(union_schema) => {
                 for (i, variant_schema) in 
union_schema.schemas.iter().enumerate() {
                     match variant_schema {
-                        Schema::Map(_) => {
+                        Schema::Map(_) | Schema::Record(_) | Schema::Ref { .. 
} => {
                             encode_int(i as i32, &mut *self.writer)?;
                             return self.serialize_map_with_schema(len, 
variant_schema);
                         }
@@ -1532,6 +1627,9 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> {
                     "Expected a Map schema in {union_schema:?}"
                 )))
             }
+            Schema::Record(record_schema) => 
Ok(SchemaAwareWriteSerializeMapOrStruct::Struct(
+                SchemaAwareWriteSerializeStruct::new(self, record_schema),
+            )),
             _ => Err(create_error(format!(
                 "Expected Map or Union schema. Got: {schema}"
             ))),
@@ -1630,7 +1728,7 @@ impl<'a, 's, W: Write> ser::Serializer for &'a mut 
SchemaAwareWriteSerializer<'s
     type SerializeTuple = SchemaAwareWriteSerializeSeq<'a, 's, W>;
     type SerializeTupleStruct = SchemaAwareWriteSerializeTupleStruct<'a, 's, 
W>;
     type SerializeTupleVariant = SchemaAwareWriteSerializeTupleStruct<'a, 's, 
W>;
-    type SerializeMap = SchemaAwareWriteSerializeMap<'a, 's, W>;
+    type SerializeMap = SchemaAwareWriteSerializeMapOrStruct<'a, 's, W>;
     type SerializeStruct = SchemaAwareWriteSerializeStruct<'a, 's, W>;
     type SerializeStructVariant = SchemaAwareWriteSerializeStruct<'a, 's, W>;
 
diff --git a/avro/src/serde/util.rs b/avro/src/serde/util.rs
new file mode 100644
index 0000000..55ea2ea
--- /dev/null
+++ b/avro/src/serde/util.rs
@@ -0,0 +1,300 @@
+use crate::{Error, error::Details};
+use serde::{
+    Serialize, Serializer,
+    ser::{
+        SerializeMap, SerializeSeq, SerializeStruct, SerializeStructVariant, 
SerializeTuple,
+        SerializeTupleStruct, SerializeTupleVariant,
+    },
+};
+
+/// Serializes a `T: Serialize` into a `String`, e.g. a map key that names a record field.
+///
+/// Every method other than 
[`Serializer::serialize_str`] returns `Details::MapFieldExpectedString`.
+pub struct StringSerializer;
+
+impl Serializer for StringSerializer { // only `serialize_str` succeeds; every other input errors
+    type Ok = String;
+    type Error = Error;
+    type SerializeSeq = Self; // compound serializers are never handed out: their entry points error
+    type SerializeTuple = Self;
+    type SerializeTupleStruct = Self;
+    type SerializeTupleVariant = Self;
+    type SerializeMap = Self;
+    type SerializeStruct = Self;
+    type SerializeStructVariant = Self;
+
+    fn serialize_bool(self, _v: bool) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_i8(self, _v: i8) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_i16(self, _v: i16) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_i32(self, _v: i32) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_i64(self, _v: i64) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_u8(self, _v: u8) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_u16(self, _v: u16) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_u32(self, _v: u32) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_u64(self, _v: u64) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_f32(self, _v: f32) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_f64(self, _v: f64) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_char(self, _v: char) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_str(self, v: &str) -> Result<Self::Ok, Self::Error> { // the one accepted input
+        Ok(v.to_string())
+    }
+
+    fn serialize_bytes(self, _v: &[u8]) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_none(self) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_some<T>(self, _value: &T) -> Result<Self::Ok, Self::Error>
+    where
+        T: ?Sized + Serialize,
+    {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_unit(self) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_unit_struct(self, _name: &'static str) -> Result<Self::Ok, 
Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_unit_variant(
+        self,
+        _name: &'static str,
+        _variant_index: u32,
+        _variant: &'static str,
+    ) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_newtype_struct<T>(
+        self,
+        _name: &'static str,
+        _value: &T,
+    ) -> Result<Self::Ok, Self::Error>
+    where
+        T: ?Sized + Serialize,
+    {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_newtype_variant<T>(
+        self,
+        _name: &'static str,
+        _variant_index: u32,
+        _variant: &'static str,
+        _value: &T,
+    ) -> Result<Self::Ok, Self::Error>
+    where
+        T: ?Sized + Serialize,
+    {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_seq(self, _len: Option<usize>) -> Result<Self::SerializeSeq, 
Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_tuple(self, _len: usize) -> Result<Self::SerializeTuple, 
Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_tuple_struct(
+        self,
+        _name: &'static str,
+        _len: usize,
+    ) -> Result<Self::SerializeTupleStruct, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_tuple_variant(
+        self,
+        _name: &'static str,
+        _variant_index: u32,
+        _variant: &'static str,
+        _len: usize,
+    ) -> Result<Self::SerializeTupleVariant, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap, 
Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_struct(
+        self,
+        _name: &'static str,
+        _len: usize,
+    ) -> Result<Self::SerializeStruct, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_struct_variant(
+        self,
+        _name: &'static str,
+        _variant_index: u32,
+        _variant: &'static str,
+        _len: usize,
+    ) -> Result<Self::SerializeStructVariant, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+}
+
+impl SerializeSeq for StringSerializer { // unreachable: `serialize_seq` never returns `Ok`
+    type Ok = String;
+    type Error = Error;
+
+    fn serialize_element<T>(&mut self, _value: &T) -> Result<(), Self::Error>
+    where
+        T: ?Sized + Serialize,
+    {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn end(self) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+}
+
+impl SerializeTuple for StringSerializer { // unreachable: `serialize_tuple` never returns `Ok`
+    type Ok = String;
+    type Error = Error;
+
+    fn serialize_element<T>(&mut self, _value: &T) -> Result<(), Self::Error>
+    where
+        T: ?Sized + Serialize,
+    {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn end(self) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+}
+
+impl SerializeTupleStruct for StringSerializer { // unreachable: entry point always errors
+    type Ok = String;
+    type Error = Error;
+
+    fn serialize_field<T>(&mut self, _value: &T) -> Result<(), Self::Error>
+    where
+        T: ?Sized + Serialize,
+    {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn end(self) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+}
+
+impl SerializeTupleVariant for StringSerializer { // unreachable: entry point always errors
+    type Ok = String;
+    type Error = Error;
+
+    fn serialize_field<T>(&mut self, _value: &T) -> Result<(), Self::Error>
+    where
+        T: ?Sized + Serialize,
+    {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn end(self) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+}
+
+impl SerializeMap for StringSerializer { // unreachable: `serialize_map` never returns `Ok`
+    type Ok = String;
+    type Error = Error;
+
+    fn serialize_key<T>(&mut self, _key: &T) -> Result<(), Self::Error>
+    where
+        T: ?Sized + Serialize,
+    {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_value<T>(&mut self, _value: &T) -> Result<(), Self::Error>
+    where
+        T: ?Sized + Serialize,
+    {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn end(self) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+}
+
+impl SerializeStruct for StringSerializer { // unreachable: `serialize_struct` never returns `Ok`
+    type Ok = String;
+    type Error = Error;
+
+    fn serialize_field<T>(&mut self, _key: &'static str, _value: &T) -> 
Result<(), Self::Error>
+    where
+        T: ?Sized + Serialize,
+    {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn end(self) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+}
+
+impl SerializeStructVariant for StringSerializer { // unreachable: entry point always errors
+    type Ok = String;
+    type Error = Error;
+
+    fn serialize_field<T>(&mut self, _key: &'static str, _value: &T) -> 
Result<(), Self::Error>
+    where
+        T: ?Sized + Serialize,
+    {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn end(self) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+}
diff --git a/avro/src/types.rs b/avro/src/types.rs
index 4448eef..5a54c3f 100644
--- a/avro/src/types.rs
+++ b/avro/src/types.rs
@@ -2701,7 +2701,7 @@ Field with name '"b"' is not a member of the map items"#,
 
     #[test]
     fn test_avro_3460_validation_with_refs_real_struct() -> TestResult {
-        use crate::ser::Serializer;
+        use crate::serde::ser::Serializer;
         use serde::Serialize;
 
         #[derive(Serialize, Clone)]
@@ -2790,7 +2790,7 @@ Field with name '"b"' is not a member of the map items"#,
     }
 
     fn avro_3674_with_or_without_namespace(with_namespace: bool) -> TestResult 
{
-        use crate::ser::Serializer;
+        use crate::serde::ser::Serializer;
         use serde::Serialize;
 
         let schema_str = r#"
@@ -2883,7 +2883,7 @@ Field with name '"b"' is not a member of the map items"#,
     }
 
     fn avro_3688_schema_resolution_panic(set_field_b: bool) -> TestResult {
-        use crate::ser::Serializer;
+        use crate::serde::ser::Serializer;
         use serde::{Deserialize, Serialize};
 
         let schema_str = r#"{
diff --git a/avro/src/writer.rs b/avro/src/writer.rs
index 3b62b16..a1ae239 100644
--- a/avro/src/writer.rs
+++ b/avro/src/writer.rs
@@ -22,7 +22,7 @@ use crate::{
     error::Details,
     headers::{HeaderBuilder, RabinFingerprintHeader},
     schema::{AvroSchema, Name, ResolvedOwnedSchema, ResolvedSchema, Schema},
-    ser_schema::SchemaAwareWriteSerializer,
+    serde::ser_schema::SchemaAwareWriteSerializer,
     types::Value,
 };
 use serde::Serialize;
diff --git a/avro_derive/src/lib.rs b/avro_derive/src/lib.rs
index c447e58..2b225fb 100644
--- a/avro_derive/src/lib.rs
+++ b/avro_derive/src/lib.rs
@@ -39,6 +39,8 @@ struct FieldOptions {
     rename: Option<String>,
     #[darling(default)]
     skip: Option<bool>,
+    #[darling(default)]
+    flatten: Option<bool>,
 }
 
 #[derive(darling::FromAttributes)]
@@ -142,26 +144,46 @@ fn get_data_struct_schema_def(
     let mut record_field_exprs = vec![];
     match s.fields {
         syn::Fields::Named(ref a) => {
-            let mut index: usize = 0;
             for field in a.named.iter() {
-                let mut name = field.ident.as_ref().unwrap().to_string(); // 
we know everything has a name
+                let mut name = field
+                    .ident
+                    .as_ref()
+                    .expect("Field must have a name")
+                    .to_string();
                 if let Some(raw_name) = name.strip_prefix("r#") {
                     name = raw_name.to_string();
                 }
                 let field_attrs =
-                    
FieldOptions::from_attributes(&field.attrs[..]).map_err(darling_to_syn)?;
+                    
FieldOptions::from_attributes(&field.attrs).map_err(darling_to_syn)?;
                 let doc =
                     preserve_optional(field_attrs.doc.or_else(|| 
extract_outer_doc(&field.attrs)));
                 match (field_attrs.rename, rename_all) {
                     (Some(rename), _) => {
                         name = rename;
                     }
-                    (None, rename_all) if !matches!(rename_all, 
RenameRule::None) => {
+                    (None, rename_all) if rename_all != RenameRule::None => {
                         name = rename_all.apply_to_field(&name);
                     }
                     _ => {}
                 }
-                if let Some(true) = field_attrs.skip {
+                if Some(true) == field_attrs.skip {
+                    continue;
+                } else if Some(true) == field_attrs.flatten {
+                    // Inline the fields of the child record at runtime, as we 
don't have access to
+                    // the schema here.
+                    let flatten_ty = &field.ty;
+                    record_field_exprs.push(quote! {
+                        if let 
::apache_avro::schema::Schema::Record(::apache_avro::schema::RecordSchema { 
fields, .. }) = #flatten_ty::get_schema() {
+                            for mut field in fields {
+                                field.position = schema_fields.len();
+                                schema_fields.push(field)
+                            }
+                        } else {
+                            panic!("Can only flatten RecordSchema, got {:?}", 
#flatten_ty::get_schema())
+                        }
+                    });
+
+                    // Don't add this field as it's been replaced by the child 
record fields
                     continue;
                 }
                 let default_value = match field_attrs.default {
@@ -181,20 +203,18 @@ fn get_data_struct_schema_def(
                 };
                 let aliases = preserve_vec(field_attrs.alias);
                 let schema_expr = type_to_schema_expr(&field.ty)?;
-                let position = index;
                 record_field_exprs.push(quote! {
-                    apache_avro::schema::RecordField {
-                            name: #name.to_string(),
-                            doc: #doc,
-                            default: #default_value,
-                            aliases: #aliases,
-                            schema: #schema_expr,
-                            order: 
apache_avro::schema::RecordFieldOrder::Ascending,
-                            position: #position,
-                            custom_attributes: Default::default(),
-                        }
+                    schema_fields.push(::apache_avro::schema::RecordField {
+                        name: #name.to_string(),
+                        doc: #doc,
+                        default: #default_value,
+                        aliases: #aliases,
+                        schema: #schema_expr,
+                        order: 
::apache_avro::schema::RecordFieldOrder::Ascending,
+                        position: schema_fields.len(),
+                        custom_attributes: Default::default(),
+                    });
                 });
-                index += 1;
             }
         }
         syn::Fields::Unnamed(_) => {
@@ -212,8 +232,14 @@ fn get_data_struct_schema_def(
     }
     let record_doc = preserve_optional(record_doc);
     let record_aliases = preserve_vec(aliases);
+    // When flatten is involved, there will be more but we don't know how 
many. This optimises for
+    // the most common case where there is no flatten.
+    let minimum_fields = record_field_exprs.len();
     Ok(quote! {
-        let schema_fields = vec![#(#record_field_exprs),*];
+        let mut schema_fields = Vec::with_capacity(#minimum_fields);
+        #(#record_field_exprs)*
+        let schema_field_set: ::std::collections::HashSet<_> = 
schema_fields.iter().map(|rf| &rf.name).collect();
+        assert_eq!(schema_fields.len(), schema_field_set.len(), "Duplicate 
field names found: {schema_fields:?}");
         let name = 
apache_avro::schema::Name::new(#full_schema_name).expect(&format!("Unable to 
parse struct name for schema {}", #full_schema_name)[..]);
         let lookup: std::collections::BTreeMap<String, usize> = schema_fields
             .iter()
@@ -683,7 +709,7 @@ mod tests {
         match syn::parse2::<DeriveInput>(test_struct) {
             Ok(mut input) => {
                 let schema_res = derive_avro_schema(&mut input);
-                let expected_token_stream = r#"let schema_fields = vec ! 
[apache_avro :: schema :: RecordField { name : "a3" . to_string () , doc : Some 
("a doc" . into ()) , default : Some (serde_json :: from_str ("123") . expect 
(format ! ("Invalid JSON: {:?}" , "123") . as_str ())) , aliases : Some (vec ! 
["a1" . into () , "a2" . into ()]) , schema : apache_avro :: schema :: Schema 
:: Int , order : apache_avro :: schema :: RecordFieldOrder :: Ascending , 
position : 0usize , custom_att [...]
+                let expected_token_stream = r#"let mut schema_fields = Vec :: 
with_capacity (1usize) ; schema_fields . push (:: apache_avro :: schema :: 
RecordField { name : "a3" . to_string () , doc : Some ("a doc" . into ()) , 
default : Some (serde_json :: from_str ("123") . expect (format ! ("Invalid 
JSON: {:?}" , "123") . as_str ())) , aliases : Some (vec ! ["a1" . into () , 
"a2" . into ()]) , schema : apache_avro :: schema :: Schema :: Int , order : :: 
apache_avro :: schema :: Recor [...]
                 let schema_token_stream = schema_res.unwrap().to_string();
                 assert!(schema_token_stream.contains(expected_token_stream));
             }
@@ -725,7 +751,7 @@ mod tests {
         match syn::parse2::<DeriveInput>(test_struct) {
             Ok(mut input) => {
                 let schema_res = derive_avro_schema(&mut input);
-                let expected_token_stream = r#"let name = apache_avro :: 
schema :: Name :: new ("A") . expect (& format ! ("Unable to parse schema name 
{}" , "A") [..]) . fully_qualified_name (enclosing_namespace) ; let 
enclosing_namespace = & name . namespace ; if named_schemas . contains_key (& 
name) { apache_avro :: schema :: Schema :: Ref { name : name . clone () } } 
else { named_schemas . insert (name . clone () , apache_avro :: schema :: 
Schema :: Ref { name : name . clone () }) ;  [...]
+                let expected_token_stream = r#"let name = apache_avro :: 
schema :: Name :: new ("A") . expect (& format ! ("Unable to parse schema name 
{}" , "A") [..]) . fully_qualified_name (enclosing_namespace) ; let 
enclosing_namespace = & name . namespace ; if named_schemas . contains_key (& 
name) { apache_avro :: schema :: Schema :: Ref { name : name . clone () } } 
else { named_schemas . insert (name . clone () , apache_avro :: schema :: 
Schema :: Ref { name : name . clone () }) ;  [...]
                 let schema_token_stream = schema_res.unwrap().to_string();
                 assert!(schema_token_stream.contains(expected_token_stream));
             }
@@ -769,7 +795,7 @@ mod tests {
         match syn::parse2::<DeriveInput>(test_struct) {
             Ok(mut input) => {
                 let schema_res = derive_avro_schema(&mut input);
-                let expected_token_stream = r#"let name = apache_avro :: 
schema :: Name :: new ("A") . expect (& format ! ("Unable to parse schema name 
{}" , "A") [..]) . fully_qualified_name (enclosing_namespace) ; let 
enclosing_namespace = & name . namespace ; if named_schemas . contains_key (& 
name) { apache_avro :: schema :: Schema :: Ref { name : name . clone () } } 
else { named_schemas . insert (name . clone () , apache_avro :: schema :: 
Schema :: Ref { name : name . clone () }) ;  [...]
+                let expected_token_stream = r#"let name = apache_avro :: 
schema :: Name :: new ("A") . expect (& format ! ("Unable to parse schema name 
{}" , "A") [..]) . fully_qualified_name (enclosing_namespace) ; let 
enclosing_namespace = & name . namespace ; if named_schemas . contains_key (& 
name) { apache_avro :: schema :: Schema :: Ref { name : name . clone () } } 
else { named_schemas . insert (name . clone () , apache_avro :: schema :: 
Schema :: Ref { name : name . clone () }) ;  [...]
                 let schema_token_stream = schema_res.unwrap().to_string();
                 assert!(schema_token_stream.contains(expected_token_stream));
             }
diff --git a/avro_derive/tests/derive.rs b/avro_derive/tests/derive.rs
index 8d92c57..6972e9a 100644
--- a/avro_derive/tests/derive.rs
+++ b/avro_derive/tests/derive.rs
@@ -1686,4 +1686,168 @@ mod test_derive {
             panic!("Unexpected schema type for Foo")
         }
     }
+
+    #[test]
+    fn avro_rs_247_serde_flatten_support() {
+        // Inner struct whose only field is hoisted into the outer record.
+        #[derive(Clone, Debug, PartialEq, AvroSchema, Serialize, Deserialize)]
+        struct Nested {
+            a: bool,
+        }
+        // Outer record: `nested` is inlined, so the schema lists `a` before `b`.
+        #[derive(Clone, Debug, PartialEq, AvroSchema, Serialize, Deserialize)]
+        struct Foo {
+            #[serde(flatten)]
+            #[avro(flatten)]
+            nested: Nested,
+            b: i32,
+        }
+
+        let expected_json = r#"
+        {
+            "type":"record",
+            "name":"Foo",
+            "fields": [
+                {
+                    "name":"a",
+                    "type":"boolean"
+                },
+                {
+                    "name":"b",
+                    "type":"int"
+                }
+            ]
+        }
+        "#;
+        assert_eq!(Foo::get_schema(), Schema::parse_str(expected_json).unwrap());
+        // Exercise serialization/deserialization of the flattened value.
+        let value = Foo {
+            nested: Nested { a: true },
+            b: 321,
+        };
+        serde_assert(value);
+    }
+
+    #[test]
+    fn avro_rs_247_serde_nested_flatten_support() {
+        // `Foo` inlines `NestedFoo`'s fields; `Bar` then embeds `Foo` as an
+        // ordinary nested record. Note that `u32` is mapped to Avro `long`.
+
+        #[derive(Debug, Serialize, Deserialize, AvroSchema, Clone, PartialEq)]
+        struct NestedFoo {
+            one: u32,
+        }
+
+        #[derive(Debug, Serialize, Deserialize, AvroSchema, Clone, PartialEq)]
+        struct Foo {
+            #[serde(flatten)]
+            #[avro(flatten)]
+            nested_foo: NestedFoo,
+        }
+
+        #[derive(Debug, Serialize, Deserialize, AvroSchema, Clone, PartialEq)]
+        struct Bar {
+            foo: Foo,
+            two: u32,
+        }
+
+        let schema = r#"
+        {
+            "type":"record",
+            "name":"Bar",
+            "fields": [
+                {
+                    "name":"foo",
+                    "type": {
+                        "type": "record",
+                        "name": "Foo",
+                        "fields": [
+                            {
+                                "name": "one",
+                                "type": "long"
+                            }
+                        ]
+                    }
+                },
+                {
+                    "name":"two",
+                    "type":"long"
+                }
+            ]
+        }
+        "#;
+
+        let schema = Schema::parse_str(schema).unwrap();
+        assert_eq!(schema, Bar::get_schema());
+
+        serde_assert(Bar {
+            foo: Foo {
+                nested_foo: NestedFoo { one: 42 },
+            },
+            two: 2,
+        });
+    }
+
+    #[test]
+    #[should_panic(expected = "Duplicate field names found")]
+    fn avro_rs_247_serde_flatten_support_duplicate_field_name() {
+        // Flattening `Nested` would inline a second field called `a`.
+        #[derive(Clone, Debug, PartialEq, AvroSchema, Serialize, Deserialize)]
+        struct Nested {
+            a: i32,
+        }
+        #[derive(Clone, Debug, PartialEq, AvroSchema, Serialize, Deserialize)]
+        struct Foo {
+            #[serde(flatten)]
+            #[avro(flatten)]
+            nested: Nested,
+            a: i32,
+        }
+        // Building the schema is expected to panic on the name clash.
+        let _ = Foo::get_schema();
+    }
+
+    #[test]
+    fn avro_rs_247_serde_flatten_support_with_skip() {
+        // `c` is skipped by both serde and the schema derive, so only `a` is inlined.
+        #[derive(Clone, Debug, PartialEq, AvroSchema, Serialize, Deserialize)]
+        struct Nested {
+            a: bool,
+            #[serde(skip)]
+            #[avro(skip)]
+            c: f64,
+        }
+
+        #[derive(Clone, Debug, PartialEq, AvroSchema, Serialize, Deserialize)]
+        struct Foo {
+            #[serde(flatten)]
+            #[avro(flatten)]
+            nested: Nested,
+            b: i32,
+        }
+
+        let expected_json = r#"
+        {
+            "type":"record",
+            "name":"Foo",
+            "fields": [
+                {
+                    "name":"a",
+                    "type":"boolean"
+                },
+                {
+                    "name":"b",
+                    "type":"int"
+                }
+            ]
+        }
+        "#;
+        assert_eq!(Foo::get_schema(), Schema::parse_str(expected_json).unwrap());
+        // `c` is 0.0, the `Default` value serde fills in for a skipped field.
+        let value = Foo {
+            nested: Nested { a: true, c: 0.0 },
+            b: 321,
+        };
+        serde_assert(value);
+    }
 }


Reply via email to