This is an automated email from the ASF dual-hosted git repository. kriskras99 pushed a commit to branch feat/serde_flatten in repository https://gitbox.apache.org/repos/asf/avro-rs.git
commit 7e403fa682d22e345f11a331af13db845d5edc00 Author: Kriskras99 <[email protected]> AuthorDate: Thu Dec 4 14:44:13 2025 +0100 feat: Implement support for `#[serde(flatten)]` This is done by adding a `#[avro(flatten)]` attribute so that the schema (for that field) is also flattened, and by adding support in `SchemaAwareWriteSerializer` for serializing a struct via Map instead of Struct. `flatten` does not work with `to_value`, as `to_value` does not have access to the schema. --- avro/src/error.rs | 3 + avro/src/serde/mod.rs | 1 + avro/src/serde/ser.rs | 7 +- avro/src/serde/ser_schema.rs | 128 ++++++++++++++++--- avro/src/serde/util.rs | 298 +++++++++++++++++++++++++++++++++++++++++++ avro_derive/src/lib.rs | 66 +++++++--- avro_derive/tests/derive.rs | 99 ++++++++++++++ 7 files changed, 561 insertions(+), 41 deletions(-) diff --git a/avro/src/error.rs b/avro/src/error.rs index a3e2cf0..187769d 100644 --- a/avro/src/error.rs +++ b/avro/src/error.rs @@ -579,6 +579,9 @@ pub enum Details { #[error("Cannot convert a slice to Uuid: {0}")] UuidFromSlice(#[source] uuid::Error), + + #[error("Expected String for Map key")] + MapFieldExpectedString, } #[derive(thiserror::Error, PartialEq)] diff --git a/avro/src/serde/mod.rs b/avro/src/serde/mod.rs index 67bd005..509d2e5 100644 --- a/avro/src/serde/mod.rs +++ b/avro/src/serde/mod.rs @@ -1,3 +1,4 @@ pub mod de; pub mod ser; pub mod ser_schema; +mod util; diff --git a/avro/src/serde/ser.rs b/avro/src/serde/ser.rs index 1bc9075..d78f501 100644 --- a/avro/src/serde/ser.rs +++ b/avro/src/serde/ser.rs @@ -479,7 +479,12 @@ impl ser::SerializeStructVariant for StructVariantSerializer<'_> { /// Interpret a serializeable instance as a `Value`. /// /// This conversion can fail if the value is not valid as per the Avro specification. -/// e.g: HashMap with non-string keys +/// e.g: `HashMap` with non-string keys. +/// +/// This function does not work if `S` has any fields (recursively) that have the `#[serde(flatten)]` +/// attribute.
Please use [`Writer::append_ser`] if that's the case. +/// +/// [`Writer::append_ser`]: crate::Writer::append_ser pub fn to_value<S: Serialize>(value: S) -> Result<Value, Error> { let mut serializer = Serializer::default(); value.serialize(&mut serializer) diff --git a/avro/src/serde/ser_schema.rs b/avro/src/serde/ser_schema.rs index cfb235a..c88cc4a 100644 --- a/avro/src/serde/ser_schema.rs +++ b/avro/src/serde/ser_schema.rs @@ -23,9 +23,10 @@ use crate::{ encode::{encode_int, encode_long}, error::{Details, Error}, schema::{Name, NamesRef, Namespace, RecordField, RecordSchema, Schema}, + serde::util::StringSerializer, }; use bigdecimal::BigDecimal; -use serde::ser; +use serde::{Serialize, ser}; use std::{borrow::Cow, io::Write, str::FromStr}; const COLLECTION_SERIALIZER_ITEM_LIMIT: usize = 1024; @@ -249,8 +250,10 @@ impl<W: Write> ser::SerializeMap for SchemaAwareWriteSerializeMap<'_, '_, W> { pub struct SchemaAwareWriteSerializeStruct<'a, 's, W: Write> { ser: &'a mut SchemaAwareWriteSerializer<'s, W>, record_schema: &'s RecordSchema, - /// Fields we received in the wrong order + /// Fields we received in the wrong order. field_cache: Vec<(usize, Vec<u8>)>, + /// The current field name when serializing from a map (for `flatten` support). 
+ map_field_name: Option<String>, next_field: usize, bytes_written: usize, } @@ -264,6 +267,7 @@ impl<'a, 's, W: Write> SchemaAwareWriteSerializeStruct<'a, 's, W> { ser, record_schema, field_cache: Vec::new(), + map_field_name: None, next_field: 0, bytes_written: 0, } @@ -353,6 +357,10 @@ impl<'a, 's, W: Write> SchemaAwareWriteSerializeStruct<'a, 's, W> { self.field_cache.is_empty(), "There should be no more unwritten fields at this point" ); + assert!( + self.map_field_name.is_none(), + "There should be no field name at this point" + ); Ok(self.bytes_written) } } @@ -372,17 +380,14 @@ impl<W: Write> ser::SerializeStruct for SchemaAwareWriteSerializeStruct<'_, '_, .and_then(|idx| self.record_schema.fields.get(*idx)); match record_field { - Some(field) => { - // self.item_count += 1; - self.serialize_next_field(field, value).map_err(|e| { - Details::SerializeRecordFieldWithSchema { - field_name: key.to_string(), - record_schema: Schema::Record(self.record_schema.clone()), - error: Box::new(e), - } - .into() - }) - } + Some(field) => self.serialize_next_field(field, value).map_err(|e| { + Details::SerializeRecordFieldWithSchema { + field_name: key.to_string(), + record_schema: Schema::Record(self.record_schema.clone()), + error: Box::new(e), + } + .into() + }), None => Err(Details::FieldName(String::from(key)).into()), } } @@ -421,6 +426,50 @@ impl<W: Write> ser::SerializeStruct for SchemaAwareWriteSerializeStruct<'_, '_, } } +impl<W: Write> ser::SerializeMap for SchemaAwareWriteSerializeStruct<'_, '_, W> { + type Ok = usize; + type Error = Error; + + fn serialize_key<T>(&mut self, key: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + let name = key.serialize(StringSerializer)?; + assert!( + self.map_field_name.replace(name).is_none(), + "Got two keys in a row" + ); + Ok(()) + } + + fn serialize_value<T>(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + let key = self.map_field_name.take().expect("Got value 
without key"); + let record_field = self + .record_schema + .lookup + .get(&key) + .and_then(|idx| self.record_schema.fields.get(*idx)); + match record_field { + Some(field) => self.serialize_next_field(field, value).map_err(|e| { + Details::SerializeRecordFieldWithSchema { + field_name: key.to_string(), + record_schema: Schema::Record(self.record_schema.clone()), + error: Box::new(e), + } + .into() + }), + None => Err(Details::FieldName(key).into()), + } + } + + fn end(self) -> Result<Self::Ok, Self::Error> { + self.end() + } +} + impl<W: Write> ser::SerializeStructVariant for SchemaAwareWriteSerializeStruct<'_, '_, W> { type Ok = usize; type Error = Error; @@ -437,6 +486,46 @@ impl<W: Write> ser::SerializeStructVariant for SchemaAwareWriteSerializeStruct<' } } +/// Map serializer that switches between Struct or Map. +/// +/// This exists because when `#[serde(flatten)]` is used, struct fields are serialized as a map. +pub enum SchemaAwareWriteSerializeMapOrStruct<'a, 's, W: Write> { + Struct(SchemaAwareWriteSerializeStruct<'a, 's, W>), + Map(SchemaAwareWriteSerializeMap<'a, 's, W>), +} + +impl<W: Write> ser::SerializeMap for SchemaAwareWriteSerializeMapOrStruct<'_, '_, W> { + type Ok = usize; + type Error = Error; + + fn serialize_key<T>(&mut self, key: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + match self { + Self::Struct(s) => s.serialize_key(key), + Self::Map(s) => s.serialize_key(key), + } + } + + fn serialize_value<T>(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + match self { + Self::Struct(s) => s.serialize_value(value), + Self::Map(s) => s.serialize_value(value), + } + } + + fn end(self) -> Result<Self::Ok, Self::Error> { + match self { + Self::Struct(s) => s.end(), + Self::Map(s) => s.end(), + } + } +} + /// The tuple struct serializer for [`SchemaAwareWriteSerializer`]. /// [`SchemaAwareWriteSerializeTupleStruct`] can serialize to an Avro array, record, or big-decimal. 
/// When serializing to a record, fields must be provided in the correct order, since no names are provided. @@ -1500,7 +1589,7 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> { &'a mut self, len: Option<usize>, schema: &'s Schema, - ) -> Result<SchemaAwareWriteSerializeMap<'a, 's, W>, Error> { + ) -> Result<SchemaAwareWriteSerializeMapOrStruct<'a, 's, W>, Error> { let create_error = |cause: String| { let len_str = len .map(|l| format!("{l}")) @@ -1514,10 +1603,8 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> { }; match schema { - Schema::Map(map_schema) => Ok(SchemaAwareWriteSerializeMap::new( - self, - map_schema.types.as_ref(), - len, + Schema::Map(map_schema) => Ok(SchemaAwareWriteSerializeMapOrStruct::Map( + SchemaAwareWriteSerializeMap::new(self, map_schema.types.as_ref(), len), )), Schema::Union(union_schema) => { for (i, variant_schema) in union_schema.schemas.iter().enumerate() { @@ -1533,6 +1620,9 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> { "Expected a Map schema in {union_schema:?}" ))) } + Schema::Record(record_schema) => Ok(SchemaAwareWriteSerializeMapOrStruct::Struct( + SchemaAwareWriteSerializeStruct::new(self, record_schema), + )), _ => Err(create_error(format!( "Expected Map or Union schema. 
Got: {schema}" ))), @@ -1631,7 +1721,7 @@ impl<'a, 's, W: Write> ser::Serializer for &'a mut SchemaAwareWriteSerializer<'s type SerializeTuple = SchemaAwareWriteSerializeSeq<'a, 's, W>; type SerializeTupleStruct = SchemaAwareWriteSerializeTupleStruct<'a, 's, W>; type SerializeTupleVariant = SchemaAwareWriteSerializeTupleStruct<'a, 's, W>; - type SerializeMap = SchemaAwareWriteSerializeMap<'a, 's, W>; + type SerializeMap = SchemaAwareWriteSerializeMapOrStruct<'a, 's, W>; type SerializeStruct = SchemaAwareWriteSerializeStruct<'a, 's, W>; type SerializeStructVariant = SchemaAwareWriteSerializeStruct<'a, 's, W>; diff --git a/avro/src/serde/util.rs b/avro/src/serde/util.rs new file mode 100644 index 0000000..94e6591 --- /dev/null +++ b/avro/src/serde/util.rs @@ -0,0 +1,298 @@ +use crate::{Error, error::Details}; +use serde::{ + Serialize, Serializer, + ser::{ + SerializeMap, SerializeSeq, SerializeStruct, SerializeStructVariant, SerializeTuple, + SerializeTupleStruct, SerializeTupleVariant, + }, +}; + +/// Serialize a `T: Serialize` as a `String`. +/// +/// An error will be returned if any other function than [`Serializer::serialize_str`] is called. 
+pub struct StringSerializer; + +impl Serializer for StringSerializer { + type Ok = String; + type Error = Error; + type SerializeSeq = Self; + type SerializeTuple = Self; + type SerializeTupleStruct = Self; + type SerializeTupleVariant = Self; + type SerializeMap = Self; + type SerializeStruct = Self; + type SerializeStructVariant = Self; + + fn serialize_bool(self, _v: bool) -> Result<Self::Ok, Self::Error> { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_i8(self, _v: i8) -> Result<Self::Ok, Self::Error> { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_i16(self, _v: i16) -> Result<Self::Ok, Self::Error> { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_i32(self, _v: i32) -> Result<Self::Ok, Self::Error> { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_i64(self, _v: i64) -> Result<Self::Ok, Self::Error> { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_u8(self, _v: u8) -> Result<Self::Ok, Self::Error> { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_u16(self, _v: u16) -> Result<Self::Ok, Self::Error> { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_u32(self, _v: u32) -> Result<Self::Ok, Self::Error> { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_u64(self, _v: u64) -> Result<Self::Ok, Self::Error> { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_f32(self, _v: f32) -> Result<Self::Ok, Self::Error> { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_f64(self, _v: f64) -> Result<Self::Ok, Self::Error> { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_char(self, _v: char) -> Result<Self::Ok, Self::Error> { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_str(self, v: &str) -> Result<Self::Ok, Self::Error> { + Ok(v.to_string()) + } + + fn serialize_bytes(self, _v: &[u8]) -> Result<Self::Ok, Self::Error> { + 
Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_none(self) -> Result<Self::Ok, Self::Error> { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_some<T>(self, _value: &T) -> Result<Self::Ok, Self::Error> + where + T: ?Sized + Serialize, + { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_unit(self) -> Result<Self::Ok, Self::Error> { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_unit_struct(self, _name: &'static str) -> Result<Self::Ok, Self::Error> { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + ) -> Result<Self::Ok, Self::Error> { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_newtype_struct<T>( + self, + _name: &'static str, + _value: &T, + ) -> Result<Self::Ok, Self::Error> + where + T: ?Sized + Serialize, + { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_newtype_variant<T>( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _value: &T, + ) -> Result<Self::Ok, Self::Error> + where + T: ?Sized + Serialize, + { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_seq(self, _len: Option<usize>) -> Result<Self::SerializeSeq, Self::Error> { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_tuple(self, _len: usize) -> Result<Self::SerializeTuple, Self::Error> { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result<Self::SerializeTupleStruct, Self::Error> { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result<Self::SerializeTupleVariant, Self::Error> { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_map(self, _len: Option<usize>) -> 
Result<Self::SerializeMap, Self::Error> { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result<Self::SerializeStruct, Self::Error> { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result<Self::SerializeStructVariant, Self::Error> { + Err(Details::MapFieldExpectedString.into()) + } +} + +impl SerializeSeq for StringSerializer { + type Ok = String; + type Error = Error; + + fn serialize_element<T>(&mut self, _value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + Err(Details::MapFieldExpectedString.into()) + } + + fn end(self) -> Result<Self::Ok, Self::Error> { + Err(Details::MapFieldExpectedString.into()) + } +} + +impl SerializeTuple for StringSerializer { + type Ok = String; + type Error = Error; + + fn serialize_element<T>(&mut self, _value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + Err(Details::MapFieldExpectedString.into()) + } + + fn end(self) -> Result<Self::Ok, Self::Error> { + Err(Details::MapFieldExpectedString.into()) + } +} + +impl SerializeTupleStruct for StringSerializer { + type Ok = String; + type Error = Error; + + fn serialize_field<T>(&mut self, _value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + Err(Details::MapFieldExpectedString.into()) + } + + fn end(self) -> Result<Self::Ok, Self::Error> { + Err(Details::MapFieldExpectedString.into()) + } +} + +impl SerializeTupleVariant for StringSerializer { + type Ok = String; + type Error = Error; + + fn serialize_field<T>(&mut self, _value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + Err(Details::MapFieldExpectedString.into()) + } + + fn end(self) -> Result<Self::Ok, Self::Error> { + Err(Details::MapFieldExpectedString.into()) + } +} + +impl SerializeMap for StringSerializer { + type Ok = 
String; + type Error = Error; + + fn serialize_key<T>(&mut self, _key: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + Err(Details::MapFieldExpectedString.into()) + } + + fn serialize_value<T>(&mut self, _value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + Err(Details::MapFieldExpectedString.into()) + } + + fn end(self) -> Result<Self::Ok, Self::Error> { + Err(Details::MapFieldExpectedString.into()) + } +} +impl SerializeStruct for StringSerializer { + type Ok = String; + type Error = Error; + + fn serialize_field<T>(&mut self, _key: &'static str, _value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + Err(Details::MapFieldExpectedString.into()) + } + + fn end(self) -> Result<Self::Ok, Self::Error> { + Err(Details::MapFieldExpectedString.into()) + } +} +impl SerializeStructVariant for StringSerializer { + type Ok = String; + type Error = Error; + + fn serialize_field<T>(&mut self, _key: &'static str, _value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + Err(Details::MapFieldExpectedString.into()) + } + + fn end(self) -> Result<Self::Ok, Self::Error> { + Err(Details::MapFieldExpectedString.into()) + } +} diff --git a/avro_derive/src/lib.rs b/avro_derive/src/lib.rs index c447e58..bd0236a 100644 --- a/avro_derive/src/lib.rs +++ b/avro_derive/src/lib.rs @@ -39,6 +39,8 @@ struct FieldOptions { rename: Option<String>, #[darling(default)] skip: Option<bool>, + #[darling(default)] + flatten: Option<bool>, } #[derive(darling::FromAttributes)] @@ -142,26 +144,46 @@ fn get_data_struct_schema_def( let mut record_field_exprs = vec![]; match s.fields { syn::Fields::Named(ref a) => { - let mut index: usize = 0; for field in a.named.iter() { - let mut name = field.ident.as_ref().unwrap().to_string(); // we know everything has a name + let mut name = field + .ident + .as_ref() + .expect("Field must have a name") + .to_string(); if let Some(raw_name) = name.strip_prefix("r#") { name = 
raw_name.to_string(); } let field_attrs = - FieldOptions::from_attributes(&field.attrs[..]).map_err(darling_to_syn)?; + FieldOptions::from_attributes(&field.attrs).map_err(darling_to_syn)?; let doc = preserve_optional(field_attrs.doc.or_else(|| extract_outer_doc(&field.attrs))); match (field_attrs.rename, rename_all) { (Some(rename), _) => { name = rename; } - (None, rename_all) if !matches!(rename_all, RenameRule::None) => { + (None, rename_all) if rename_all != RenameRule::None => { name = rename_all.apply_to_field(&name); } _ => {} } - if let Some(true) = field_attrs.skip { + if Some(true) == field_attrs.skip { + continue; + } else if Some(true) == field_attrs.flatten { + // Inline the fields of the child record at runtime, as we don't have access to + // the schema here. + let flatten_ty = &field.ty; + record_field_exprs.push(quote! { + if let ::apache_avro::schema::Schema::Record(::apache_avro::schema::RecordSchema { fields, .. }) = #flatten_ty::get_schema() { + for mut field in fields { + field.position = schema_fields.len(); + schema_fields.push(field) + } + } else { + panic!("Can only flatten RecordSchema") + } + }); + + // Don't add this field as it's been replaced by the child record fields continue; } let default_value = match field_attrs.default { @@ -181,20 +203,18 @@ fn get_data_struct_schema_def( }; let aliases = preserve_vec(field_attrs.alias); let schema_expr = type_to_schema_expr(&field.ty)?; - let position = index; record_field_exprs.push(quote! 
{ - apache_avro::schema::RecordField { - name: #name.to_string(), - doc: #doc, - default: #default_value, - aliases: #aliases, - schema: #schema_expr, - order: apache_avro::schema::RecordFieldOrder::Ascending, - position: #position, - custom_attributes: Default::default(), - } + schema_fields.push(::apache_avro::schema::RecordField { + name: #name.to_string(), + doc: #doc, + default: #default_value, + aliases: #aliases, + schema: #schema_expr, + order: ::apache_avro::schema::RecordFieldOrder::Ascending, + position: schema_fields.len(), + custom_attributes: Default::default(), + }); }); - index += 1; } } syn::Fields::Unnamed(_) => { @@ -212,8 +232,12 @@ fn get_data_struct_schema_def( } let record_doc = preserve_optional(record_doc); let record_aliases = preserve_vec(aliases); + // When flatten is involved, there will be more but we don't know how many. This optimises for + // the most common case where there is no flatten. + let minimum_fields = record_field_exprs.len(); Ok(quote! { - let schema_fields = vec![#(#record_field_exprs),*]; + let mut schema_fields = Vec::with_capacity(#minimum_fields); + #(#record_field_exprs)* let name = apache_avro::schema::Name::new(#full_schema_name).expect(&format!("Unable to parse struct name for schema {}", #full_schema_name)[..]); let lookup: std::collections::BTreeMap<String, usize> = schema_fields .iter() @@ -683,7 +707,7 @@ mod tests { match syn::parse2::<DeriveInput>(test_struct) { Ok(mut input) => { let schema_res = derive_avro_schema(&mut input); - let expected_token_stream = r#"let schema_fields = vec ! [apache_avro :: schema :: RecordField { name : "a3" . to_string () , doc : Some ("a doc" . into ()) , default : Some (serde_json :: from_str ("123") . expect (format ! ("Invalid JSON: {:?}" , "123") . as_str ())) , aliases : Some (vec ! ["a1" . into () , "a2" . into ()]) , schema : apache_avro :: schema :: Schema :: Int , order : apache_avro :: schema :: RecordFieldOrder :: Ascending , position : 0usize , custom_att [...] 
+ let expected_token_stream = r#"let mut schema_fields = Vec :: with_capacity (1usize) ; schema_fields . push (:: apache_avro :: schema :: RecordField { name : "a3" . to_string () , doc : Some ("a doc" . into ()) , default : Some (serde_json :: from_str ("123") . expect (format ! ("Invalid JSON: {:?}" , "123") . as_str ())) , aliases : Some (vec ! ["a1" . into () , "a2" . into ()]) , schema : apache_avro :: schema :: Schema :: Int , order : :: apache_avro :: schema :: Recor [...] let schema_token_stream = schema_res.unwrap().to_string(); assert!(schema_token_stream.contains(expected_token_stream)); } @@ -725,7 +749,7 @@ mod tests { match syn::parse2::<DeriveInput>(test_struct) { Ok(mut input) => { let schema_res = derive_avro_schema(&mut input); - let expected_token_stream = r#"let name = apache_avro :: schema :: Name :: new ("A") . expect (& format ! ("Unable to parse schema name {}" , "A") [..]) . fully_qualified_name (enclosing_namespace) ; let enclosing_namespace = & name . namespace ; if named_schemas . contains_key (& name) { apache_avro :: schema :: Schema :: Ref { name : name . clone () } } else { named_schemas . insert (name . clone () , apache_avro :: schema :: Schema :: Ref { name : name . clone () }) ; [...] + let expected_token_stream = r#"let name = apache_avro :: schema :: Name :: new ("A") . expect (& format ! ("Unable to parse schema name {}" , "A") [..]) . fully_qualified_name (enclosing_namespace) ; let enclosing_namespace = & name . namespace ; if named_schemas . contains_key (& name) { apache_avro :: schema :: Schema :: Ref { name : name . clone () } } else { named_schemas . insert (name . clone () , apache_avro :: schema :: Schema :: Ref { name : name . clone () }) ; [...] 
let schema_token_stream = schema_res.unwrap().to_string(); assert!(schema_token_stream.contains(expected_token_stream)); } @@ -769,7 +793,7 @@ mod tests { match syn::parse2::<DeriveInput>(test_struct) { Ok(mut input) => { let schema_res = derive_avro_schema(&mut input); - let expected_token_stream = r#"let name = apache_avro :: schema :: Name :: new ("A") . expect (& format ! ("Unable to parse schema name {}" , "A") [..]) . fully_qualified_name (enclosing_namespace) ; let enclosing_namespace = & name . namespace ; if named_schemas . contains_key (& name) { apache_avro :: schema :: Schema :: Ref { name : name . clone () } } else { named_schemas . insert (name . clone () , apache_avro :: schema :: Schema :: Ref { name : name . clone () }) ; [...] + let expected_token_stream = r#"let name = apache_avro :: schema :: Name :: new ("A") . expect (& format ! ("Unable to parse schema name {}" , "A") [..]) . fully_qualified_name (enclosing_namespace) ; let enclosing_namespace = & name . namespace ; if named_schemas . contains_key (& name) { apache_avro :: schema :: Schema :: Ref { name : name . clone () } } else { named_schemas . insert (name . clone () , apache_avro :: schema :: Schema :: Ref { name : name . clone () }) ; [...] 
let schema_token_stream = schema_res.unwrap().to_string(); assert!(schema_token_stream.contains(expected_token_stream)); } diff --git a/avro_derive/tests/derive.rs b/avro_derive/tests/derive.rs index 8d92c57..3b7be71 100644 --- a/avro_derive/tests/derive.rs +++ b/avro_derive/tests/derive.rs @@ -1686,4 +1686,103 @@ mod test_derive { panic!("Unexpected schema type for Foo") } } + + #[test] + fn avro_247_serde_flatten_support() { + #[derive(Debug, Serialize, Deserialize, AvroSchema, Clone, PartialEq)] + struct Nested { + a: bool, + } + + #[derive(Debug, Serialize, Deserialize, AvroSchema, Clone, PartialEq)] + struct Foo { + #[serde(flatten)] + #[avro(flatten)] + nested: Nested, + b: i32, + } + + let schema = r#" + { + "type":"record", + "name":"Foo", + "fields": [ + { + "name":"a", + "type":"boolean" + }, + { + "name":"b", + "type":"int" + } + ] + } + "#; + + let schema = Schema::parse_str(schema).unwrap(); + let derived_schema = Foo::get_schema(); + if let Schema::Record(RecordSchema { name, fields, .. 
}) = &derived_schema { + assert_eq!("Foo", name.fullname(None)); + for field in fields { + match field.name.as_str() { + "a" | "b" => (), // expected + name => panic!("Unexpected field name '{name}'"), + } + } + } else { + panic!("Foo schema must be a record schema: {derived_schema:?}") + } + assert_eq!(schema, derived_schema); + + serde_assert(Foo { + nested: Nested { a: true }, + b: 321, + }); + } + + #[test] + fn avro_247_serde_nested_flatten_support() { + use apache_avro::{AvroSchema, Reader, Writer, from_value}; + use serde::{Deserialize, Serialize}; + + #[derive(AvroSchema, Serialize, Deserialize, PartialEq, Debug)] + pub struct NestedFoo { + one: u32, + } + + #[derive(AvroSchema, Debug, Serialize, Deserialize, PartialEq)] + pub struct Foo { + #[serde(flatten)] + #[avro(flatten)] + nested_foo: NestedFoo, + } + + #[derive(AvroSchema, Serialize, Debug, Deserialize, PartialEq)] + struct Bar { + foo: Foo, + two: u32, + } + + let bar = Bar { + foo: Foo { + nested_foo: NestedFoo { one: 42 }, + }, + two: 2, + // test_enum: TestEnum::B, + }; + + let schema = Bar::get_schema(); + println!("Generated schema: {:#?}", schema); + + // When appending a value, use this crate's special extension method + let mut writer = Writer::new(&schema, Vec::new()).unwrap(); + writer.append_ser(&bar).unwrap(); + + // Check that it was correctly serialized and is deserializable + let encoded = writer.into_inner().unwrap(); + let mut reader = Reader::new(&encoded[..]).unwrap(); + let value = reader.next().unwrap().unwrap(); + let result: Bar = from_value(&value).unwrap(); + assert_eq!(result, bar); + } }
