This is an automated email from the ASF dual-hosted git repository.

kriskras99 pushed a commit to branch feat/serde_flatten
in repository https://gitbox.apache.org/repos/asf/avro-rs.git

commit 7e403fa682d22e345f11a331af13db845d5edc00
Author: Kriskras99 <[email protected]>
AuthorDate: Thu Dec 4 14:44:13 2025 +0100

    feat: Implement support for `#[serde(flatten)]`
    
    This is done by adding a `#[avro(flatten)]` attribute so that the
    schema (for that field) is also flattened, and by adding support
    in `SchemaAwareWriteSerializer` for serializing a struct via serde's
    `SerializeMap` interface instead of `SerializeStruct`.
    
    `flatten` does not work with `to_value`, as `to_value` does not have
    access to the schema.
---
 avro/src/error.rs            |   3 +
 avro/src/serde/mod.rs        |   1 +
 avro/src/serde/ser.rs        |   7 +-
 avro/src/serde/ser_schema.rs | 128 ++++++++++++++++---
 avro/src/serde/util.rs       | 298 +++++++++++++++++++++++++++++++++++++++++++
 avro_derive/src/lib.rs       |  66 +++++++---
 avro_derive/tests/derive.rs  |  99 ++++++++++++++
 7 files changed, 561 insertions(+), 41 deletions(-)

diff --git a/avro/src/error.rs b/avro/src/error.rs
index a3e2cf0..187769d 100644
--- a/avro/src/error.rs
+++ b/avro/src/error.rs
@@ -579,6 +579,9 @@ pub enum Details {
 
     #[error("Cannot convert a slice to Uuid: {0}")]
     UuidFromSlice(#[source] uuid::Error),
+
+    #[error("Expected String for Map key")]
+    MapFieldExpectedString,
 }
 
 #[derive(thiserror::Error, PartialEq)]
diff --git a/avro/src/serde/mod.rs b/avro/src/serde/mod.rs
index 67bd005..509d2e5 100644
--- a/avro/src/serde/mod.rs
+++ b/avro/src/serde/mod.rs
@@ -1,3 +1,4 @@
 pub mod de;
 pub mod ser;
 pub mod ser_schema;
+mod util;
diff --git a/avro/src/serde/ser.rs b/avro/src/serde/ser.rs
index 1bc9075..d78f501 100644
--- a/avro/src/serde/ser.rs
+++ b/avro/src/serde/ser.rs
@@ -479,7 +479,12 @@ impl ser::SerializeStructVariant for 
StructVariantSerializer<'_> {
 /// Interpret a serializeable instance as a `Value`.
 ///
 /// This conversion can fail if the value is not valid as per the Avro 
specification.
-/// e.g: HashMap with non-string keys
+/// e.g: `HashMap` with non-string keys.
+///
+/// This function does not work if `S` has any fields (recursively) that have 
the `#[serde(flatten)]`
+/// attribute. Please use [`Writer::append_ser`] if that's the case.
+///
+/// [`Writer::append_ser`]: crate::Writer::append_ser
 pub fn to_value<S: Serialize>(value: S) -> Result<Value, Error> {
     let mut serializer = Serializer::default();
     value.serialize(&mut serializer)
diff --git a/avro/src/serde/ser_schema.rs b/avro/src/serde/ser_schema.rs
index cfb235a..c88cc4a 100644
--- a/avro/src/serde/ser_schema.rs
+++ b/avro/src/serde/ser_schema.rs
@@ -23,9 +23,10 @@ use crate::{
     encode::{encode_int, encode_long},
     error::{Details, Error},
     schema::{Name, NamesRef, Namespace, RecordField, RecordSchema, Schema},
+    serde::util::StringSerializer,
 };
 use bigdecimal::BigDecimal;
-use serde::ser;
+use serde::{Serialize, ser};
 use std::{borrow::Cow, io::Write, str::FromStr};
 
 const COLLECTION_SERIALIZER_ITEM_LIMIT: usize = 1024;
@@ -249,8 +250,10 @@ impl<W: Write> ser::SerializeMap for 
SchemaAwareWriteSerializeMap<'_, '_, W> {
 pub struct SchemaAwareWriteSerializeStruct<'a, 's, W: Write> {
     ser: &'a mut SchemaAwareWriteSerializer<'s, W>,
     record_schema: &'s RecordSchema,
-    /// Fields we received in the wrong order
+    /// Fields we received in the wrong order.
     field_cache: Vec<(usize, Vec<u8>)>,
+    /// The current field name when serializing from a map (for `flatten` 
support).
+    map_field_name: Option<String>,
     next_field: usize,
     bytes_written: usize,
 }
@@ -264,6 +267,7 @@ impl<'a, 's, W: Write> SchemaAwareWriteSerializeStruct<'a, 
's, W> {
             ser,
             record_schema,
             field_cache: Vec::new(),
+            map_field_name: None,
             next_field: 0,
             bytes_written: 0,
         }
@@ -353,6 +357,10 @@ impl<'a, 's, W: Write> SchemaAwareWriteSerializeStruct<'a, 
's, W> {
             self.field_cache.is_empty(),
             "There should be no more unwritten fields at this point"
         );
+        assert!(
+            self.map_field_name.is_none(),
+            "There should be no field name at this point"
+        );
         Ok(self.bytes_written)
     }
 }
@@ -372,17 +380,14 @@ impl<W: Write> ser::SerializeStruct for 
SchemaAwareWriteSerializeStruct<'_, '_,
             .and_then(|idx| self.record_schema.fields.get(*idx));
 
         match record_field {
-            Some(field) => {
-                // self.item_count += 1;
-                self.serialize_next_field(field, value).map_err(|e| {
-                    Details::SerializeRecordFieldWithSchema {
-                        field_name: key.to_string(),
-                        record_schema: 
Schema::Record(self.record_schema.clone()),
-                        error: Box::new(e),
-                    }
-                    .into()
-                })
-            }
+            Some(field) => self.serialize_next_field(field, value).map_err(|e| 
{
+                Details::SerializeRecordFieldWithSchema {
+                    field_name: key.to_string(),
+                    record_schema: Schema::Record(self.record_schema.clone()),
+                    error: Box::new(e),
+                }
+                .into()
+            }),
             None => Err(Details::FieldName(String::from(key)).into()),
         }
     }
@@ -421,6 +426,50 @@ impl<W: Write> ser::SerializeStruct for 
SchemaAwareWriteSerializeStruct<'_, '_,
     }
 }
 
+impl<W: Write> ser::SerializeMap for SchemaAwareWriteSerializeStruct<'_, '_, 
W> {
+    type Ok = usize;
+    type Error = Error;
+
+    fn serialize_key<T>(&mut self, key: &T) -> Result<(), Self::Error>
+    where
+        T: ?Sized + Serialize,
+    {
+        let name = key.serialize(StringSerializer)?;
+        assert!(
+            self.map_field_name.replace(name).is_none(),
+            "Got two keys in a row"
+        );
+        Ok(())
+    }
+
+    fn serialize_value<T>(&mut self, value: &T) -> Result<(), Self::Error>
+    where
+        T: ?Sized + Serialize,
+    {
+        let key = self.map_field_name.take().expect("Got value without key");
+        let record_field = self
+            .record_schema
+            .lookup
+            .get(&key)
+            .and_then(|idx| self.record_schema.fields.get(*idx));
+        match record_field {
+            Some(field) => self.serialize_next_field(field, value).map_err(|e| 
{
+                Details::SerializeRecordFieldWithSchema {
+                    field_name: key.to_string(),
+                    record_schema: Schema::Record(self.record_schema.clone()),
+                    error: Box::new(e),
+                }
+                .into()
+            }),
+            None => Err(Details::FieldName(key).into()),
+        }
+    }
+
+    fn end(self) -> Result<Self::Ok, Self::Error> {
+        self.end()
+    }
+}
+
 impl<W: Write> ser::SerializeStructVariant for 
SchemaAwareWriteSerializeStruct<'_, '_, W> {
     type Ok = usize;
     type Error = Error;
@@ -437,6 +486,46 @@ impl<W: Write> ser::SerializeStructVariant for 
SchemaAwareWriteSerializeStruct<'
     }
 }
 
+/// Map serializer that switches between Struct or Map.
+///
+/// This exists because when `#[serde(flatten)]` is used, struct fields are 
serialized as a map.
+pub enum SchemaAwareWriteSerializeMapOrStruct<'a, 's, W: Write> {
+    Struct(SchemaAwareWriteSerializeStruct<'a, 's, W>),
+    Map(SchemaAwareWriteSerializeMap<'a, 's, W>),
+}
+
+impl<W: Write> ser::SerializeMap for SchemaAwareWriteSerializeMapOrStruct<'_, 
'_, W> {
+    type Ok = usize;
+    type Error = Error;
+
+    fn serialize_key<T>(&mut self, key: &T) -> Result<(), Self::Error>
+    where
+        T: ?Sized + Serialize,
+    {
+        match self {
+            Self::Struct(s) => s.serialize_key(key),
+            Self::Map(s) => s.serialize_key(key),
+        }
+    }
+
+    fn serialize_value<T>(&mut self, value: &T) -> Result<(), Self::Error>
+    where
+        T: ?Sized + Serialize,
+    {
+        match self {
+            Self::Struct(s) => s.serialize_value(value),
+            Self::Map(s) => s.serialize_value(value),
+        }
+    }
+
+    fn end(self) -> Result<Self::Ok, Self::Error> {
+        match self {
+            Self::Struct(s) => s.end(),
+            Self::Map(s) => s.end(),
+        }
+    }
+}
+
 /// The tuple struct serializer for [`SchemaAwareWriteSerializer`].
 /// [`SchemaAwareWriteSerializeTupleStruct`] can serialize to an Avro array, 
record, or big-decimal.
 /// When serializing to a record, fields must be provided in the correct 
order, since no names are provided.
@@ -1500,7 +1589,7 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> {
         &'a mut self,
         len: Option<usize>,
         schema: &'s Schema,
-    ) -> Result<SchemaAwareWriteSerializeMap<'a, 's, W>, Error> {
+    ) -> Result<SchemaAwareWriteSerializeMapOrStruct<'a, 's, W>, Error> {
         let create_error = |cause: String| {
             let len_str = len
                 .map(|l| format!("{l}"))
@@ -1514,10 +1603,8 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> {
         };
 
         match schema {
-            Schema::Map(map_schema) => Ok(SchemaAwareWriteSerializeMap::new(
-                self,
-                map_schema.types.as_ref(),
-                len,
+            Schema::Map(map_schema) => 
Ok(SchemaAwareWriteSerializeMapOrStruct::Map(
+                SchemaAwareWriteSerializeMap::new(self, 
map_schema.types.as_ref(), len),
             )),
             Schema::Union(union_schema) => {
                 for (i, variant_schema) in 
union_schema.schemas.iter().enumerate() {
@@ -1533,6 +1620,9 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> {
                     "Expected a Map schema in {union_schema:?}"
                 )))
             }
+            Schema::Record(record_schema) => 
Ok(SchemaAwareWriteSerializeMapOrStruct::Struct(
+                SchemaAwareWriteSerializeStruct::new(self, record_schema),
+            )),
             _ => Err(create_error(format!(
                 "Expected Map or Union schema. Got: {schema}"
             ))),
@@ -1631,7 +1721,7 @@ impl<'a, 's, W: Write> ser::Serializer for &'a mut 
SchemaAwareWriteSerializer<'s
     type SerializeTuple = SchemaAwareWriteSerializeSeq<'a, 's, W>;
     type SerializeTupleStruct = SchemaAwareWriteSerializeTupleStruct<'a, 's, 
W>;
     type SerializeTupleVariant = SchemaAwareWriteSerializeTupleStruct<'a, 's, 
W>;
-    type SerializeMap = SchemaAwareWriteSerializeMap<'a, 's, W>;
+    type SerializeMap = SchemaAwareWriteSerializeMapOrStruct<'a, 's, W>;
     type SerializeStruct = SchemaAwareWriteSerializeStruct<'a, 's, W>;
     type SerializeStructVariant = SchemaAwareWriteSerializeStruct<'a, 's, W>;
 
diff --git a/avro/src/serde/util.rs b/avro/src/serde/util.rs
new file mode 100644
index 0000000..94e6591
--- /dev/null
+++ b/avro/src/serde/util.rs
@@ -0,0 +1,298 @@
+use crate::{Error, error::Details};
+use serde::{
+    Serialize, Serializer,
+    ser::{
+        SerializeMap, SerializeSeq, SerializeStruct, SerializeStructVariant, 
SerializeTuple,
+        SerializeTupleStruct, SerializeTupleVariant,
+    },
+};
+
+/// Serialize a `T: Serialize` as a `String`.
+///
+/// An error will be returned if any other function than 
[`Serializer::serialize_str`] is called.
+pub struct StringSerializer;
+
+impl Serializer for StringSerializer {
+    type Ok = String;
+    type Error = Error;
+    type SerializeSeq = Self;
+    type SerializeTuple = Self;
+    type SerializeTupleStruct = Self;
+    type SerializeTupleVariant = Self;
+    type SerializeMap = Self;
+    type SerializeStruct = Self;
+    type SerializeStructVariant = Self;
+
+    fn serialize_bool(self, _v: bool) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_i8(self, _v: i8) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_i16(self, _v: i16) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_i32(self, _v: i32) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_i64(self, _v: i64) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_u8(self, _v: u8) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_u16(self, _v: u16) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_u32(self, _v: u32) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_u64(self, _v: u64) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_f32(self, _v: f32) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_f64(self, _v: f64) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_char(self, _v: char) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_str(self, v: &str) -> Result<Self::Ok, Self::Error> {
+        Ok(v.to_string())
+    }
+
+    fn serialize_bytes(self, _v: &[u8]) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_none(self) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_some<T>(self, _value: &T) -> Result<Self::Ok, Self::Error>
+    where
+        T: ?Sized + Serialize,
+    {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_unit(self) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_unit_struct(self, _name: &'static str) -> Result<Self::Ok, 
Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_unit_variant(
+        self,
+        _name: &'static str,
+        _variant_index: u32,
+        _variant: &'static str,
+    ) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_newtype_struct<T>(
+        self,
+        _name: &'static str,
+        _value: &T,
+    ) -> Result<Self::Ok, Self::Error>
+    where
+        T: ?Sized + Serialize,
+    {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_newtype_variant<T>(
+        self,
+        _name: &'static str,
+        _variant_index: u32,
+        _variant: &'static str,
+        _value: &T,
+    ) -> Result<Self::Ok, Self::Error>
+    where
+        T: ?Sized + Serialize,
+    {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_seq(self, _len: Option<usize>) -> Result<Self::SerializeSeq, 
Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_tuple(self, _len: usize) -> Result<Self::SerializeTuple, 
Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_tuple_struct(
+        self,
+        _name: &'static str,
+        _len: usize,
+    ) -> Result<Self::SerializeTupleStruct, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_tuple_variant(
+        self,
+        _name: &'static str,
+        _variant_index: u32,
+        _variant: &'static str,
+        _len: usize,
+    ) -> Result<Self::SerializeTupleVariant, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap, 
Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_struct(
+        self,
+        _name: &'static str,
+        _len: usize,
+    ) -> Result<Self::SerializeStruct, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_struct_variant(
+        self,
+        _name: &'static str,
+        _variant_index: u32,
+        _variant: &'static str,
+        _len: usize,
+    ) -> Result<Self::SerializeStructVariant, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+}
+
+impl SerializeSeq for StringSerializer {
+    type Ok = String;
+    type Error = Error;
+
+    fn serialize_element<T>(&mut self, _value: &T) -> Result<(), Self::Error>
+    where
+        T: ?Sized + Serialize,
+    {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn end(self) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+}
+
+impl SerializeTuple for StringSerializer {
+    type Ok = String;
+    type Error = Error;
+
+    fn serialize_element<T>(&mut self, _value: &T) -> Result<(), Self::Error>
+    where
+        T: ?Sized + Serialize,
+    {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn end(self) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+}
+
+impl SerializeTupleStruct for StringSerializer {
+    type Ok = String;
+    type Error = Error;
+
+    fn serialize_field<T>(&mut self, _value: &T) -> Result<(), Self::Error>
+    where
+        T: ?Sized + Serialize,
+    {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn end(self) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+}
+
+impl SerializeTupleVariant for StringSerializer {
+    type Ok = String;
+    type Error = Error;
+
+    fn serialize_field<T>(&mut self, _value: &T) -> Result<(), Self::Error>
+    where
+        T: ?Sized + Serialize,
+    {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn end(self) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+}
+
+impl SerializeMap for StringSerializer {
+    type Ok = String;
+    type Error = Error;
+
+    fn serialize_key<T>(&mut self, _key: &T) -> Result<(), Self::Error>
+    where
+        T: ?Sized + Serialize,
+    {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn serialize_value<T>(&mut self, _value: &T) -> Result<(), Self::Error>
+    where
+        T: ?Sized + Serialize,
+    {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn end(self) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+}
+impl SerializeStruct for StringSerializer {
+    type Ok = String;
+    type Error = Error;
+
+    fn serialize_field<T>(&mut self, _key: &'static str, _value: &T) -> 
Result<(), Self::Error>
+    where
+        T: ?Sized + Serialize,
+    {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn end(self) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+}
+impl SerializeStructVariant for StringSerializer {
+    type Ok = String;
+    type Error = Error;
+
+    fn serialize_field<T>(&mut self, _key: &'static str, _value: &T) -> 
Result<(), Self::Error>
+    where
+        T: ?Sized + Serialize,
+    {
+        Err(Details::MapFieldExpectedString.into())
+    }
+
+    fn end(self) -> Result<Self::Ok, Self::Error> {
+        Err(Details::MapFieldExpectedString.into())
+    }
+}
diff --git a/avro_derive/src/lib.rs b/avro_derive/src/lib.rs
index c447e58..bd0236a 100644
--- a/avro_derive/src/lib.rs
+++ b/avro_derive/src/lib.rs
@@ -39,6 +39,8 @@ struct FieldOptions {
     rename: Option<String>,
     #[darling(default)]
     skip: Option<bool>,
+    #[darling(default)]
+    flatten: Option<bool>,
 }
 
 #[derive(darling::FromAttributes)]
@@ -142,26 +144,46 @@ fn get_data_struct_schema_def(
     let mut record_field_exprs = vec![];
     match s.fields {
         syn::Fields::Named(ref a) => {
-            let mut index: usize = 0;
             for field in a.named.iter() {
-                let mut name = field.ident.as_ref().unwrap().to_string(); // 
we know everything has a name
+                let mut name = field
+                    .ident
+                    .as_ref()
+                    .expect("Field must have a name")
+                    .to_string();
                 if let Some(raw_name) = name.strip_prefix("r#") {
                     name = raw_name.to_string();
                 }
                 let field_attrs =
-                    
FieldOptions::from_attributes(&field.attrs[..]).map_err(darling_to_syn)?;
+                    
FieldOptions::from_attributes(&field.attrs).map_err(darling_to_syn)?;
                 let doc =
                     preserve_optional(field_attrs.doc.or_else(|| 
extract_outer_doc(&field.attrs)));
                 match (field_attrs.rename, rename_all) {
                     (Some(rename), _) => {
                         name = rename;
                     }
-                    (None, rename_all) if !matches!(rename_all, 
RenameRule::None) => {
+                    (None, rename_all) if rename_all != RenameRule::None => {
                         name = rename_all.apply_to_field(&name);
                     }
                     _ => {}
                 }
-                if let Some(true) = field_attrs.skip {
+                if Some(true) == field_attrs.skip {
+                    continue;
+                } else if Some(true) == field_attrs.flatten {
+                    // Inline the fields of the child record at runtime, as we 
don't have access to
+                    // the schema here.
+                    let flatten_ty = &field.ty;
+                    record_field_exprs.push(quote! {
+                        if let 
::apache_avro::schema::Schema::Record(::apache_avro::schema::RecordSchema { 
fields, .. }) = #flatten_ty::get_schema() {
+                            for mut field in fields {
+                                field.position = schema_fields.len();
+                                schema_fields.push(field)
+                            }
+                        } else {
+                            panic!("Can only flatten RecordSchema")
+                        }
+                    });
+
+                    // Don't add this field as it's been replaced by the child 
record fields
                     continue;
                 }
                 let default_value = match field_attrs.default {
@@ -181,20 +203,18 @@ fn get_data_struct_schema_def(
                 };
                 let aliases = preserve_vec(field_attrs.alias);
                 let schema_expr = type_to_schema_expr(&field.ty)?;
-                let position = index;
                 record_field_exprs.push(quote! {
-                    apache_avro::schema::RecordField {
-                            name: #name.to_string(),
-                            doc: #doc,
-                            default: #default_value,
-                            aliases: #aliases,
-                            schema: #schema_expr,
-                            order: 
apache_avro::schema::RecordFieldOrder::Ascending,
-                            position: #position,
-                            custom_attributes: Default::default(),
-                        }
+                    schema_fields.push(::apache_avro::schema::RecordField {
+                        name: #name.to_string(),
+                        doc: #doc,
+                        default: #default_value,
+                        aliases: #aliases,
+                        schema: #schema_expr,
+                        order: 
::apache_avro::schema::RecordFieldOrder::Ascending,
+                        position: schema_fields.len(),
+                        custom_attributes: Default::default(),
+                    });
                 });
-                index += 1;
             }
         }
         syn::Fields::Unnamed(_) => {
@@ -212,8 +232,12 @@ fn get_data_struct_schema_def(
     }
     let record_doc = preserve_optional(record_doc);
     let record_aliases = preserve_vec(aliases);
+    // When flatten is involved, there will be more but we don't know how 
many. This optimises for
+    // the most common case where there is no flatten.
+    let minimum_fields = record_field_exprs.len();
     Ok(quote! {
-        let schema_fields = vec![#(#record_field_exprs),*];
+        let mut schema_fields = Vec::with_capacity(#minimum_fields);
+        #(#record_field_exprs)*
         let name = 
apache_avro::schema::Name::new(#full_schema_name).expect(&format!("Unable to 
parse struct name for schema {}", #full_schema_name)[..]);
         let lookup: std::collections::BTreeMap<String, usize> = schema_fields
             .iter()
@@ -683,7 +707,7 @@ mod tests {
         match syn::parse2::<DeriveInput>(test_struct) {
             Ok(mut input) => {
                 let schema_res = derive_avro_schema(&mut input);
-                let expected_token_stream = r#"let schema_fields = vec ! 
[apache_avro :: schema :: RecordField { name : "a3" . to_string () , doc : Some 
("a doc" . into ()) , default : Some (serde_json :: from_str ("123") . expect 
(format ! ("Invalid JSON: {:?}" , "123") . as_str ())) , aliases : Some (vec ! 
["a1" . into () , "a2" . into ()]) , schema : apache_avro :: schema :: Schema 
:: Int , order : apache_avro :: schema :: RecordFieldOrder :: Ascending , 
position : 0usize , custom_att [...]
+                let expected_token_stream = r#"let mut schema_fields = Vec :: 
with_capacity (1usize) ; schema_fields . push (:: apache_avro :: schema :: 
RecordField { name : "a3" . to_string () , doc : Some ("a doc" . into ()) , 
default : Some (serde_json :: from_str ("123") . expect (format ! ("Invalid 
JSON: {:?}" , "123") . as_str ())) , aliases : Some (vec ! ["a1" . into () , 
"a2" . into ()]) , schema : apache_avro :: schema :: Schema :: Int , order : :: 
apache_avro :: schema :: Recor [...]
                 let schema_token_stream = schema_res.unwrap().to_string();
                 assert!(schema_token_stream.contains(expected_token_stream));
             }
@@ -725,7 +749,7 @@ mod tests {
         match syn::parse2::<DeriveInput>(test_struct) {
             Ok(mut input) => {
                 let schema_res = derive_avro_schema(&mut input);
-                let expected_token_stream = r#"let name = apache_avro :: 
schema :: Name :: new ("A") . expect (& format ! ("Unable to parse schema name 
{}" , "A") [..]) . fully_qualified_name (enclosing_namespace) ; let 
enclosing_namespace = & name . namespace ; if named_schemas . contains_key (& 
name) { apache_avro :: schema :: Schema :: Ref { name : name . clone () } } 
else { named_schemas . insert (name . clone () , apache_avro :: schema :: 
Schema :: Ref { name : name . clone () }) ;  [...]
+                let expected_token_stream = r#"let name = apache_avro :: 
schema :: Name :: new ("A") . expect (& format ! ("Unable to parse schema name 
{}" , "A") [..]) . fully_qualified_name (enclosing_namespace) ; let 
enclosing_namespace = & name . namespace ; if named_schemas . contains_key (& 
name) { apache_avro :: schema :: Schema :: Ref { name : name . clone () } } 
else { named_schemas . insert (name . clone () , apache_avro :: schema :: 
Schema :: Ref { name : name . clone () }) ;  [...]
                 let schema_token_stream = schema_res.unwrap().to_string();
                 assert!(schema_token_stream.contains(expected_token_stream));
             }
@@ -769,7 +793,7 @@ mod tests {
         match syn::parse2::<DeriveInput>(test_struct) {
             Ok(mut input) => {
                 let schema_res = derive_avro_schema(&mut input);
-                let expected_token_stream = r#"let name = apache_avro :: 
schema :: Name :: new ("A") . expect (& format ! ("Unable to parse schema name 
{}" , "A") [..]) . fully_qualified_name (enclosing_namespace) ; let 
enclosing_namespace = & name . namespace ; if named_schemas . contains_key (& 
name) { apache_avro :: schema :: Schema :: Ref { name : name . clone () } } 
else { named_schemas . insert (name . clone () , apache_avro :: schema :: 
Schema :: Ref { name : name . clone () }) ;  [...]
+                let expected_token_stream = r#"let name = apache_avro :: 
schema :: Name :: new ("A") . expect (& format ! ("Unable to parse schema name 
{}" , "A") [..]) . fully_qualified_name (enclosing_namespace) ; let 
enclosing_namespace = & name . namespace ; if named_schemas . contains_key (& 
name) { apache_avro :: schema :: Schema :: Ref { name : name . clone () } } 
else { named_schemas . insert (name . clone () , apache_avro :: schema :: 
Schema :: Ref { name : name . clone () }) ;  [...]
                 let schema_token_stream = schema_res.unwrap().to_string();
                 assert!(schema_token_stream.contains(expected_token_stream));
             }
diff --git a/avro_derive/tests/derive.rs b/avro_derive/tests/derive.rs
index 8d92c57..3b7be71 100644
--- a/avro_derive/tests/derive.rs
+++ b/avro_derive/tests/derive.rs
@@ -1686,4 +1686,103 @@ mod test_derive {
             panic!("Unexpected schema type for Foo")
         }
     }
+
+    #[test]
+    fn avro_247_serde_flatten_support() {
+        #[derive(Debug, Serialize, Deserialize, AvroSchema, Clone, PartialEq)]
+        struct Nested {
+            a: bool,
+        }
+
+        #[derive(Debug, Serialize, Deserialize, AvroSchema, Clone, PartialEq)]
+        struct Foo {
+            #[serde(flatten)]
+            #[avro(flatten)]
+            nested: Nested,
+            b: i32,
+        }
+
+        let schema = r#"
+        {
+            "type":"record",
+            "name":"Foo",
+            "fields": [
+                {
+                    "name":"a",
+                    "type":"boolean"
+                },
+                {
+                    "name":"b",
+                    "type":"int"
+                }
+            ]
+        }
+        "#;
+
+        let schema = Schema::parse_str(schema).unwrap();
+        let derived_schema = Foo::get_schema();
+        if let Schema::Record(RecordSchema { name, fields, .. }) = 
&derived_schema {
+            assert_eq!("Foo", name.fullname(None));
+            for field in fields {
+                match field.name.as_str() {
+                    "a" | "b" => (), // expected
+                    name => panic!("Unexpected field name '{name}'"),
+                }
+            }
+        } else {
+            panic!("Foo schema must be a record schema: {derived_schema:?}")
+        }
+        assert_eq!(schema, derived_schema);
+
+        serde_assert(Foo {
+            nested: Nested { a: true },
+            b: 321,
+        });
+    }
+
+    #[test]
+    fn avro_247_serde_nested_flatten_support() {
+        use apache_avro::{AvroSchema, Reader, Writer, from_value};
+        use serde::{Deserialize, Serialize};
+
+        #[derive(AvroSchema, Serialize, Deserialize, PartialEq, Debug)]
+        pub struct NestedFoo {
+            one: u32,
+        }
+
+        #[derive(AvroSchema, Debug, Serialize, Deserialize, PartialEq)]
+        pub struct Foo {
+            #[serde(flatten)]
+            #[avro(flatten)]
+            nested_foo: NestedFoo,
+        }
+
+        #[derive(AvroSchema, Serialize, Debug, Deserialize, PartialEq)]
+        struct Bar {
+            foo: Foo,
+            two: u32,
+        }
+
+        let bar = Bar {
+            foo: Foo {
+                nested_foo: NestedFoo { one: 42 },
+            },
+            two: 2,
+            // test_enum: TestEnum::B,
+        };
+
+        let schema = Bar::get_schema();
+        println!("Generated schema: {:#?}", schema);
+
+        // When appending a value, use this crate's special extension method
+        let mut writer = Writer::new(&schema, Vec::new()).unwrap();
+        writer.append_ser(&bar).unwrap();
+
+        // Check that it was correctly serialized and is deserializable
+        let encoded = writer.into_inner().unwrap();
+        let mut reader = Reader::new(&encoded[..]).unwrap();
+        let value = reader.next().unwrap().unwrap();
+        let result: Bar = from_value(&value).unwrap();
+        assert_eq!(result, bar);
+    }
 }

Reply via email to