This is an automated email from the ASF dual-hosted git repository.
mgrigorov pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/avro-rs.git
The following commit(s) were added to refs/heads/main by this push:
new 2019363 feat: Implement support for `#[serde(flatten)]` (#359)
2019363 is described below
commit 2019363e573b95ab3c6bdda8cafa982ec052f7e9
Author: Kriskras99 <[email protected]>
AuthorDate: Thu Dec 11 21:30:34 2025 +0100
feat: Implement support for `#[serde(flatten)]` (#359)
* chore: Move all Serde related modules to the `serde` module
* feat: Implement support for `#[serde(flatten)]`
This is done by adding a `#[avro(flatten)]` attribute so that the
schema (for that field) is also flattened, and by adding support
in `SchemaAwareWriteSerializer` for serializing a struct via Map
instead of Struct.
`flatten` does not work with `to_value`, as `to_value` does not have
access to the schema.
* fix: Handle duplicate fields when flatten is used
---------
Co-authored-by: default <[email protected]>
---
avro/src/error.rs | 6 +
avro/src/lib.rs | 7 +-
avro/src/{ => serde}/de.rs | 0
avro/src/serde/mod.rs | 4 +
avro/src/{ => serde}/ser.rs | 7 +-
avro/src/{ => serde}/ser_schema.rs | 136 ++++++++++++++---
avro/src/serde/util.rs | 300 +++++++++++++++++++++++++++++++++++++
avro/src/types.rs | 6 +-
avro/src/writer.rs | 2 +-
avro_derive/src/lib.rs | 68 ++++++---
avro_derive/tests/derive.rs | 164 ++++++++++++++++++++
11 files changed, 650 insertions(+), 50 deletions(-)
diff --git a/avro/src/error.rs b/avro/src/error.rs
index a3e2cf0..bdb2055 100644
--- a/avro/src/error.rs
+++ b/avro/src/error.rs
@@ -579,6 +579,12 @@ pub enum Details {
#[error("Cannot convert a slice to Uuid: {0}")]
UuidFromSlice(#[source] uuid::Error),
+
+ #[error("Expected String for Map key when serializing a flattened struct")]
+ MapFieldExpectedString,
+
+ #[error("No key for value when serializing a map")]
+ MapNoKey,
}
#[derive(thiserror::Error, PartialEq)]
diff --git a/avro/src/lib.rs b/avro/src/lib.rs
index 853722d..f75c5a3 100644
--- a/avro/src/lib.rs
+++ b/avro/src/lib.rs
@@ -945,14 +945,12 @@
mod bigdecimal;
mod bytes;
mod codec;
-mod de;
mod decimal;
mod decode;
mod duration;
mod encode;
mod reader;
-mod ser;
-mod ser_schema;
+mod serde;
mod writer;
pub mod error;
@@ -979,7 +977,6 @@ pub use codec::xz::XzSettings;
#[cfg(feature = "zstandard")]
pub use codec::zstandard::ZstandardSettings;
pub use codec::{Codec, DeflateSettings};
-pub use de::from_value;
pub use decimal::Decimal;
pub use duration::{Days, Duration, Millis, Months};
pub use error::Error;
@@ -988,7 +985,7 @@ pub use reader::{
from_avro_datum_reader_schemata, from_avro_datum_schemata, read_marker,
};
pub use schema::{AvroSchema, Schema};
-pub use ser::to_value;
+pub use serde::{de::from_value, ser::to_value};
pub use uuid::Uuid;
pub use writer::{
GenericSingleObjectWriter, SpecificSingleObjectWriter, Writer,
WriterBuilder, to_avro_datum,
diff --git a/avro/src/de.rs b/avro/src/serde/de.rs
similarity index 100%
rename from avro/src/de.rs
rename to avro/src/serde/de.rs
diff --git a/avro/src/serde/mod.rs b/avro/src/serde/mod.rs
new file mode 100644
index 0000000..509d2e5
--- /dev/null
+++ b/avro/src/serde/mod.rs
@@ -0,0 +1,4 @@
+pub mod de;
+pub mod ser;
+pub mod ser_schema;
+mod util;
diff --git a/avro/src/ser.rs b/avro/src/serde/ser.rs
similarity index 99%
rename from avro/src/ser.rs
rename to avro/src/serde/ser.rs
index 1bc9075..d78f501 100644
--- a/avro/src/ser.rs
+++ b/avro/src/serde/ser.rs
@@ -479,7 +479,12 @@ impl ser::SerializeStructVariant for
StructVariantSerializer<'_> {
/// Interpret a serializeable instance as a `Value`.
///
/// This conversion can fail if the value is not valid as per the Avro
specification.
-/// e.g: HashMap with non-string keys
+/// e.g: `HashMap` with non-string keys.
+///
+/// This function does not work if `S` has any fields (recursively) that have
the `#[serde(flatten)]`
+/// attribute. Please use [`Writer::append_ser`] if that's the case.
+///
+/// [`Writer::append_ser`]: crate::Writer::append_ser
pub fn to_value<S: Serialize>(value: S) -> Result<Value, Error> {
let mut serializer = Serializer::default();
value.serialize(&mut serializer)
diff --git a/avro/src/ser_schema.rs b/avro/src/serde/ser_schema.rs
similarity index 96%
rename from avro/src/ser_schema.rs
rename to avro/src/serde/ser_schema.rs
index f9ee2fc..02a65bc 100644
--- a/avro/src/ser_schema.rs
+++ b/avro/src/serde/ser_schema.rs
@@ -23,9 +23,10 @@ use crate::{
encode::{encode_int, encode_long},
error::{Details, Error},
schema::{Name, NamesRef, Namespace, RecordField, RecordSchema, Schema},
+ serde::util::StringSerializer,
};
use bigdecimal::BigDecimal;
-use serde::ser;
+use serde::{Serialize, ser};
use std::{borrow::Cow, cmp::Ordering, collections::HashMap, io::Write,
str::FromStr};
const COLLECTION_SERIALIZER_ITEM_LIMIT: usize = 1024;
@@ -251,6 +252,8 @@ pub struct SchemaAwareWriteSerializeStruct<'a, 's, W:
Write> {
record_schema: &'s RecordSchema,
/// Fields we received in the wrong order
field_cache: HashMap<usize, Vec<u8>>,
+ /// The current field name when serializing from a map (for `flatten`
support).
+ map_field_name: Option<String>,
field_position: usize,
bytes_written: usize,
}
@@ -264,6 +267,7 @@ impl<'a, 's, W: Write> SchemaAwareWriteSerializeStruct<'a,
's, W> {
ser,
record_schema,
field_cache: HashMap::new(),
+ map_field_name: None,
field_position: 0,
bytes_written: 0,
}
@@ -352,6 +356,11 @@ impl<'a, 's, W: Write> SchemaAwareWriteSerializeStruct<'a,
's, W> {
"There should be no more unwritten fields at this point: {:?}",
self.field_cache
);
+ debug_assert!(
+ self.map_field_name.is_none(),
+ "There should be no field name at this point: field {:?}",
+ self.map_field_name
+ );
Ok(self.bytes_written)
}
}
@@ -371,17 +380,14 @@ impl<W: Write> ser::SerializeStruct for
SchemaAwareWriteSerializeStruct<'_, '_,
.and_then(|idx| self.record_schema.fields.get(*idx));
match record_field {
- Some(field) => {
- // self.item_count += 1;
- self.serialize_next_field(field, value).map_err(|e| {
- Details::SerializeRecordFieldWithSchema {
- field_name: key.to_string(),
- record_schema:
Schema::Record(self.record_schema.clone()),
- error: Box::new(e),
- }
- .into()
- })
- }
+ Some(field) => self.serialize_next_field(field, value).map_err(|e|
{
+ Details::SerializeRecordFieldWithSchema {
+ field_name: key.to_string(),
+ record_schema: Schema::Record(self.record_schema.clone()),
+ error: Box::new(e),
+ }
+ .into()
+ }),
None => Err(Details::FieldName(String::from(key)).into()),
}
}
@@ -420,6 +426,53 @@ impl<W: Write> ser::SerializeStruct for
SchemaAwareWriteSerializeStruct<'_, '_,
}
}
+/// This implementation is used to support `#[serde(flatten)]` as that uses
SerializeMap instead of SerializeStruct.
+impl<W: Write> ser::SerializeMap for SchemaAwareWriteSerializeStruct<'_, '_,
W> {
+ type Ok = usize;
+ type Error = Error;
+
+ fn serialize_key<T>(&mut self, key: &T) -> Result<(), Self::Error>
+ where
+ T: ?Sized + Serialize,
+ {
+ let name = key.serialize(StringSerializer)?;
+ let old = self.map_field_name.replace(name);
+ debug_assert!(
+ old.is_none(),
+ "Expected a value instead of a key: old key: {old:?}, new key:
{:?}",
+ self.map_field_name
+ );
+ Ok(())
+ }
+
+ fn serialize_value<T>(&mut self, value: &T) -> Result<(), Self::Error>
+ where
+ T: ?Sized + Serialize,
+ {
+ let key = self.map_field_name.take().ok_or(Details::MapNoKey)?;
+ let record_field = self
+ .record_schema
+ .lookup
+ .get(&key)
+ .and_then(|idx| self.record_schema.fields.get(*idx));
+ match record_field {
+ Some(field) => self.serialize_next_field(field, value).map_err(|e|
{
+ Details::SerializeRecordFieldWithSchema {
+ field_name: key.to_string(),
+ record_schema: Schema::Record(self.record_schema.clone()),
+ error: Box::new(e),
+ }
+ .into()
+ }),
+ None => Err(Details::FieldName(key).into()),
+ }
+ }
+
+ fn end(self) -> Result<Self::Ok, Self::Error> {
+ self.end()
+ }
+}
+
impl<W: Write> ser::SerializeStructVariant for
SchemaAwareWriteSerializeStruct<'_, '_, W> {
type Ok = usize;
type Error = Error;
@@ -436,6 +489,46 @@ impl<W: Write> ser::SerializeStructVariant for
SchemaAwareWriteSerializeStruct<'
}
}
+/// Map serializer that switches between Struct or Map.
+///
+/// This exists because when `#[serde(flatten)]` is used, struct fields are
serialized as a map.
+pub enum SchemaAwareWriteSerializeMapOrStruct<'a, 's, W: Write> {
+ Struct(SchemaAwareWriteSerializeStruct<'a, 's, W>),
+ Map(SchemaAwareWriteSerializeMap<'a, 's, W>),
+}
+
+impl<W: Write> ser::SerializeMap for SchemaAwareWriteSerializeMapOrStruct<'_,
'_, W> {
+ type Ok = usize;
+ type Error = Error;
+
+ fn serialize_key<T>(&mut self, key: &T) -> Result<(), Self::Error>
+ where
+ T: ?Sized + Serialize,
+ {
+ match self {
+ Self::Struct(s) => s.serialize_key(key),
+ Self::Map(s) => s.serialize_key(key),
+ }
+ }
+
+ fn serialize_value<T>(&mut self, value: &T) -> Result<(), Self::Error>
+ where
+ T: ?Sized + Serialize,
+ {
+ match self {
+ Self::Struct(s) => s.serialize_value(value),
+ Self::Map(s) => s.serialize_value(value),
+ }
+ }
+
+ fn end(self) -> Result<Self::Ok, Self::Error> {
+ match self {
+ Self::Struct(s) => s.end(),
+ Self::Map(s) => s.end(),
+ }
+ }
+}
+
/// The tuple struct serializer for [`SchemaAwareWriteSerializer`].
/// [`SchemaAwareWriteSerializeTupleStruct`] can serialize to an Avro array,
record, or big-decimal.
/// When serializing to a record, fields must be provided in the correct
order, since no names are provided.
@@ -1499,7 +1592,7 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> {
&'a mut self,
len: Option<usize>,
schema: &'s Schema,
- ) -> Result<SchemaAwareWriteSerializeMap<'a, 's, W>, Error> {
+ ) -> Result<SchemaAwareWriteSerializeMapOrStruct<'a, 's, W>, Error> {
let create_error = |cause: String| {
let len_str = len
.map(|l| format!("{l}"))
@@ -1513,15 +1606,17 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> {
};
match schema {
- Schema::Map(map_schema) => Ok(SchemaAwareWriteSerializeMap::new(
- self,
- map_schema.types.as_ref(),
- len,
+ Schema::Map(map_schema) =>
Ok(SchemaAwareWriteSerializeMapOrStruct::Map(
+ SchemaAwareWriteSerializeMap::new(self,
map_schema.types.as_ref(), len),
)),
+ Schema::Ref { name: ref_name } => {
+ let ref_schema = self.get_ref_schema(ref_name)?;
+ self.serialize_map_with_schema(len, ref_schema)
+ }
Schema::Union(union_schema) => {
for (i, variant_schema) in
union_schema.schemas.iter().enumerate() {
match variant_schema {
- Schema::Map(_) => {
+ Schema::Map(_) | Schema::Record(_) | Schema::Ref { ..
} => {
encode_int(i as i32, &mut *self.writer)?;
return self.serialize_map_with_schema(len,
variant_schema);
}
@@ -1532,6 +1627,9 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> {
"Expected a Map schema in {union_schema:?}"
)))
}
+ Schema::Record(record_schema) =>
Ok(SchemaAwareWriteSerializeMapOrStruct::Struct(
+ SchemaAwareWriteSerializeStruct::new(self, record_schema),
+ )),
_ => Err(create_error(format!(
"Expected Map or Union schema. Got: {schema}"
))),
@@ -1630,7 +1728,7 @@ impl<'a, 's, W: Write> ser::Serializer for &'a mut
SchemaAwareWriteSerializer<'s
type SerializeTuple = SchemaAwareWriteSerializeSeq<'a, 's, W>;
type SerializeTupleStruct = SchemaAwareWriteSerializeTupleStruct<'a, 's,
W>;
type SerializeTupleVariant = SchemaAwareWriteSerializeTupleStruct<'a, 's,
W>;
- type SerializeMap = SchemaAwareWriteSerializeMap<'a, 's, W>;
+ type SerializeMap = SchemaAwareWriteSerializeMapOrStruct<'a, 's, W>;
type SerializeStruct = SchemaAwareWriteSerializeStruct<'a, 's, W>;
type SerializeStructVariant = SchemaAwareWriteSerializeStruct<'a, 's, W>;
diff --git a/avro/src/serde/util.rs b/avro/src/serde/util.rs
new file mode 100644
index 0000000..55ea2ea
--- /dev/null
+++ b/avro/src/serde/util.rs
@@ -0,0 +1,300 @@
+use crate::{Error, error::Details};
+use serde::{
+ Serialize, Serializer,
+ ser::{
+ SerializeMap, SerializeSeq, SerializeStruct, SerializeStructVariant,
SerializeTuple,
+ SerializeTupleStruct, SerializeTupleVariant,
+ },
+};
+
+/// Serialize a `T: Serialize` as a `String`.
+///
+/// An error will be returned if any other function than
[`Serializer::serialize_str`] is called.
+pub struct StringSerializer;
+
+impl Serializer for StringSerializer {
+ type Ok = String;
+ type Error = Error;
+ type SerializeSeq = Self;
+ type SerializeTuple = Self;
+ type SerializeTupleStruct = Self;
+ type SerializeTupleVariant = Self;
+ type SerializeMap = Self;
+ type SerializeStruct = Self;
+ type SerializeStructVariant = Self;
+
+ fn serialize_bool(self, _v: bool) -> Result<Self::Ok, Self::Error> {
+ Err(Details::MapFieldExpectedString.into())
+ }
+
+ fn serialize_i8(self, _v: i8) -> Result<Self::Ok, Self::Error> {
+ Err(Details::MapFieldExpectedString.into())
+ }
+
+ fn serialize_i16(self, _v: i16) -> Result<Self::Ok, Self::Error> {
+ Err(Details::MapFieldExpectedString.into())
+ }
+
+ fn serialize_i32(self, _v: i32) -> Result<Self::Ok, Self::Error> {
+ Err(Details::MapFieldExpectedString.into())
+ }
+
+ fn serialize_i64(self, _v: i64) -> Result<Self::Ok, Self::Error> {
+ Err(Details::MapFieldExpectedString.into())
+ }
+
+ fn serialize_u8(self, _v: u8) -> Result<Self::Ok, Self::Error> {
+ Err(Details::MapFieldExpectedString.into())
+ }
+
+ fn serialize_u16(self, _v: u16) -> Result<Self::Ok, Self::Error> {
+ Err(Details::MapFieldExpectedString.into())
+ }
+
+ fn serialize_u32(self, _v: u32) -> Result<Self::Ok, Self::Error> {
+ Err(Details::MapFieldExpectedString.into())
+ }
+
+ fn serialize_u64(self, _v: u64) -> Result<Self::Ok, Self::Error> {
+ Err(Details::MapFieldExpectedString.into())
+ }
+
+ fn serialize_f32(self, _v: f32) -> Result<Self::Ok, Self::Error> {
+ Err(Details::MapFieldExpectedString.into())
+ }
+
+ fn serialize_f64(self, _v: f64) -> Result<Self::Ok, Self::Error> {
+ Err(Details::MapFieldExpectedString.into())
+ }
+
+ fn serialize_char(self, _v: char) -> Result<Self::Ok, Self::Error> {
+ Err(Details::MapFieldExpectedString.into())
+ }
+
+ fn serialize_str(self, v: &str) -> Result<Self::Ok, Self::Error> {
+ Ok(v.to_string())
+ }
+
+ fn serialize_bytes(self, _v: &[u8]) -> Result<Self::Ok, Self::Error> {
+ Err(Details::MapFieldExpectedString.into())
+ }
+
+ fn serialize_none(self) -> Result<Self::Ok, Self::Error> {
+ Err(Details::MapFieldExpectedString.into())
+ }
+
+ fn serialize_some<T>(self, _value: &T) -> Result<Self::Ok, Self::Error>
+ where
+ T: ?Sized + Serialize,
+ {
+ Err(Details::MapFieldExpectedString.into())
+ }
+
+ fn serialize_unit(self) -> Result<Self::Ok, Self::Error> {
+ Err(Details::MapFieldExpectedString.into())
+ }
+
+ fn serialize_unit_struct(self, _name: &'static str) -> Result<Self::Ok,
Self::Error> {
+ Err(Details::MapFieldExpectedString.into())
+ }
+
+ fn serialize_unit_variant(
+ self,
+ _name: &'static str,
+ _variant_index: u32,
+ _variant: &'static str,
+ ) -> Result<Self::Ok, Self::Error> {
+ Err(Details::MapFieldExpectedString.into())
+ }
+
+ fn serialize_newtype_struct<T>(
+ self,
+ _name: &'static str,
+ _value: &T,
+ ) -> Result<Self::Ok, Self::Error>
+ where
+ T: ?Sized + Serialize,
+ {
+ Err(Details::MapFieldExpectedString.into())
+ }
+
+ fn serialize_newtype_variant<T>(
+ self,
+ _name: &'static str,
+ _variant_index: u32,
+ _variant: &'static str,
+ _value: &T,
+ ) -> Result<Self::Ok, Self::Error>
+ where
+ T: ?Sized + Serialize,
+ {
+ Err(Details::MapFieldExpectedString.into())
+ }
+
+ fn serialize_seq(self, _len: Option<usize>) -> Result<Self::SerializeSeq,
Self::Error> {
+ Err(Details::MapFieldExpectedString.into())
+ }
+
+ fn serialize_tuple(self, _len: usize) -> Result<Self::SerializeTuple,
Self::Error> {
+ Err(Details::MapFieldExpectedString.into())
+ }
+
+ fn serialize_tuple_struct(
+ self,
+ _name: &'static str,
+ _len: usize,
+ ) -> Result<Self::SerializeTupleStruct, Self::Error> {
+ Err(Details::MapFieldExpectedString.into())
+ }
+
+ fn serialize_tuple_variant(
+ self,
+ _name: &'static str,
+ _variant_index: u32,
+ _variant: &'static str,
+ _len: usize,
+ ) -> Result<Self::SerializeTupleVariant, Self::Error> {
+ Err(Details::MapFieldExpectedString.into())
+ }
+
+ fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap,
Self::Error> {
+ Err(Details::MapFieldExpectedString.into())
+ }
+
+ fn serialize_struct(
+ self,
+ _name: &'static str,
+ _len: usize,
+ ) -> Result<Self::SerializeStruct, Self::Error> {
+ Err(Details::MapFieldExpectedString.into())
+ }
+
+ fn serialize_struct_variant(
+ self,
+ _name: &'static str,
+ _variant_index: u32,
+ _variant: &'static str,
+ _len: usize,
+ ) -> Result<Self::SerializeStructVariant, Self::Error> {
+ Err(Details::MapFieldExpectedString.into())
+ }
+}
+
+impl SerializeSeq for StringSerializer {
+ type Ok = String;
+ type Error = Error;
+
+ fn serialize_element<T>(&mut self, _value: &T) -> Result<(), Self::Error>
+ where
+ T: ?Sized + Serialize,
+ {
+ Err(Details::MapFieldExpectedString.into())
+ }
+
+ fn end(self) -> Result<Self::Ok, Self::Error> {
+ Err(Details::MapFieldExpectedString.into())
+ }
+}
+
+impl SerializeTuple for StringSerializer {
+ type Ok = String;
+ type Error = Error;
+
+ fn serialize_element<T>(&mut self, _value: &T) -> Result<(), Self::Error>
+ where
+ T: ?Sized + Serialize,
+ {
+ Err(Details::MapFieldExpectedString.into())
+ }
+
+ fn end(self) -> Result<Self::Ok, Self::Error> {
+ Err(Details::MapFieldExpectedString.into())
+ }
+}
+
+impl SerializeTupleStruct for StringSerializer {
+ type Ok = String;
+ type Error = Error;
+
+ fn serialize_field<T>(&mut self, _value: &T) -> Result<(), Self::Error>
+ where
+ T: ?Sized + Serialize,
+ {
+ Err(Details::MapFieldExpectedString.into())
+ }
+
+ fn end(self) -> Result<Self::Ok, Self::Error> {
+ Err(Details::MapFieldExpectedString.into())
+ }
+}
+
+impl SerializeTupleVariant for StringSerializer {
+ type Ok = String;
+ type Error = Error;
+
+ fn serialize_field<T>(&mut self, _value: &T) -> Result<(), Self::Error>
+ where
+ T: ?Sized + Serialize,
+ {
+ Err(Details::MapFieldExpectedString.into())
+ }
+
+ fn end(self) -> Result<Self::Ok, Self::Error> {
+ Err(Details::MapFieldExpectedString.into())
+ }
+}
+
+impl SerializeMap for StringSerializer {
+ type Ok = String;
+ type Error = Error;
+
+ fn serialize_key<T>(&mut self, _key: &T) -> Result<(), Self::Error>
+ where
+ T: ?Sized + Serialize,
+ {
+ Err(Details::MapFieldExpectedString.into())
+ }
+
+ fn serialize_value<T>(&mut self, _value: &T) -> Result<(), Self::Error>
+ where
+ T: ?Sized + Serialize,
+ {
+ Err(Details::MapFieldExpectedString.into())
+ }
+
+ fn end(self) -> Result<Self::Ok, Self::Error> {
+ Err(Details::MapFieldExpectedString.into())
+ }
+}
+
+impl SerializeStruct for StringSerializer {
+ type Ok = String;
+ type Error = Error;
+
+ fn serialize_field<T>(&mut self, _key: &'static str, _value: &T) ->
Result<(), Self::Error>
+ where
+ T: ?Sized + Serialize,
+ {
+ Err(Details::MapFieldExpectedString.into())
+ }
+
+ fn end(self) -> Result<Self::Ok, Self::Error> {
+ Err(Details::MapFieldExpectedString.into())
+ }
+}
+
+impl SerializeStructVariant for StringSerializer {
+ type Ok = String;
+ type Error = Error;
+
+ fn serialize_field<T>(&mut self, _key: &'static str, _value: &T) ->
Result<(), Self::Error>
+ where
+ T: ?Sized + Serialize,
+ {
+ Err(Details::MapFieldExpectedString.into())
+ }
+
+ fn end(self) -> Result<Self::Ok, Self::Error> {
+ Err(Details::MapFieldExpectedString.into())
+ }
+}
diff --git a/avro/src/types.rs b/avro/src/types.rs
index 4448eef..5a54c3f 100644
--- a/avro/src/types.rs
+++ b/avro/src/types.rs
@@ -2701,7 +2701,7 @@ Field with name '"b"' is not a member of the map items"#,
#[test]
fn test_avro_3460_validation_with_refs_real_struct() -> TestResult {
- use crate::ser::Serializer;
+ use crate::serde::ser::Serializer;
use serde::Serialize;
#[derive(Serialize, Clone)]
@@ -2790,7 +2790,7 @@ Field with name '"b"' is not a member of the map items"#,
}
fn avro_3674_with_or_without_namespace(with_namespace: bool) -> TestResult
{
- use crate::ser::Serializer;
+ use crate::serde::ser::Serializer;
use serde::Serialize;
let schema_str = r#"
@@ -2883,7 +2883,7 @@ Field with name '"b"' is not a member of the map items"#,
}
fn avro_3688_schema_resolution_panic(set_field_b: bool) -> TestResult {
- use crate::ser::Serializer;
+ use crate::serde::ser::Serializer;
use serde::{Deserialize, Serialize};
let schema_str = r#"{
diff --git a/avro/src/writer.rs b/avro/src/writer.rs
index 3b62b16..a1ae239 100644
--- a/avro/src/writer.rs
+++ b/avro/src/writer.rs
@@ -22,7 +22,7 @@ use crate::{
error::Details,
headers::{HeaderBuilder, RabinFingerprintHeader},
schema::{AvroSchema, Name, ResolvedOwnedSchema, ResolvedSchema, Schema},
- ser_schema::SchemaAwareWriteSerializer,
+ serde::ser_schema::SchemaAwareWriteSerializer,
types::Value,
};
use serde::Serialize;
diff --git a/avro_derive/src/lib.rs b/avro_derive/src/lib.rs
index c447e58..2b225fb 100644
--- a/avro_derive/src/lib.rs
+++ b/avro_derive/src/lib.rs
@@ -39,6 +39,8 @@ struct FieldOptions {
rename: Option<String>,
#[darling(default)]
skip: Option<bool>,
+ #[darling(default)]
+ flatten: Option<bool>,
}
#[derive(darling::FromAttributes)]
@@ -142,26 +144,46 @@ fn get_data_struct_schema_def(
let mut record_field_exprs = vec![];
match s.fields {
syn::Fields::Named(ref a) => {
- let mut index: usize = 0;
for field in a.named.iter() {
- let mut name = field.ident.as_ref().unwrap().to_string(); //
we know everything has a name
+ let mut name = field
+ .ident
+ .as_ref()
+ .expect("Field must have a name")
+ .to_string();
if let Some(raw_name) = name.strip_prefix("r#") {
name = raw_name.to_string();
}
let field_attrs =
-
FieldOptions::from_attributes(&field.attrs[..]).map_err(darling_to_syn)?;
+
FieldOptions::from_attributes(&field.attrs).map_err(darling_to_syn)?;
let doc =
preserve_optional(field_attrs.doc.or_else(||
extract_outer_doc(&field.attrs)));
match (field_attrs.rename, rename_all) {
(Some(rename), _) => {
name = rename;
}
- (None, rename_all) if !matches!(rename_all,
RenameRule::None) => {
+ (None, rename_all) if rename_all != RenameRule::None => {
name = rename_all.apply_to_field(&name);
}
_ => {}
}
- if let Some(true) = field_attrs.skip {
+ if Some(true) == field_attrs.skip {
+ continue;
+ } else if Some(true) == field_attrs.flatten {
+ // Inline the fields of the child record at runtime, as we
don't have access to
+ // the schema here.
+ let flatten_ty = &field.ty;
+ record_field_exprs.push(quote! {
+ if let
::apache_avro::schema::Schema::Record(::apache_avro::schema::RecordSchema {
fields, .. }) = #flatten_ty::get_schema() {
+ for mut field in fields {
+ field.position = schema_fields.len();
+ schema_fields.push(field)
+ }
+ } else {
+ panic!("Can only flatten RecordSchema, got {:?}",
#flatten_ty::get_schema())
+ }
+ });
+
+ // Don't add this field as it's been replaced by the child
record fields
continue;
}
let default_value = match field_attrs.default {
@@ -181,20 +203,18 @@ fn get_data_struct_schema_def(
};
let aliases = preserve_vec(field_attrs.alias);
let schema_expr = type_to_schema_expr(&field.ty)?;
- let position = index;
record_field_exprs.push(quote! {
- apache_avro::schema::RecordField {
- name: #name.to_string(),
- doc: #doc,
- default: #default_value,
- aliases: #aliases,
- schema: #schema_expr,
- order:
apache_avro::schema::RecordFieldOrder::Ascending,
- position: #position,
- custom_attributes: Default::default(),
- }
+ schema_fields.push(::apache_avro::schema::RecordField {
+ name: #name.to_string(),
+ doc: #doc,
+ default: #default_value,
+ aliases: #aliases,
+ schema: #schema_expr,
+ order:
::apache_avro::schema::RecordFieldOrder::Ascending,
+ position: schema_fields.len(),
+ custom_attributes: Default::default(),
+ });
});
- index += 1;
}
}
syn::Fields::Unnamed(_) => {
@@ -212,8 +232,14 @@ fn get_data_struct_schema_def(
}
let record_doc = preserve_optional(record_doc);
let record_aliases = preserve_vec(aliases);
+ // When flatten is involved, there will be more but we don't know how
many. This optimises for
+ // the most common case where there is no flatten.
+ let minimum_fields = record_field_exprs.len();
Ok(quote! {
- let schema_fields = vec![#(#record_field_exprs),*];
+ let mut schema_fields = Vec::with_capacity(#minimum_fields);
+ #(#record_field_exprs)*
+ let schema_field_set: ::std::collections::HashSet<_> =
schema_fields.iter().map(|rf| &rf.name).collect();
+ assert_eq!(schema_fields.len(), schema_field_set.len(), "Duplicate
field names found: {schema_fields:?}");
let name =
apache_avro::schema::Name::new(#full_schema_name).expect(&format!("Unable to
parse struct name for schema {}", #full_schema_name)[..]);
let lookup: std::collections::BTreeMap<String, usize> = schema_fields
.iter()
@@ -683,7 +709,7 @@ mod tests {
match syn::parse2::<DeriveInput>(test_struct) {
Ok(mut input) => {
let schema_res = derive_avro_schema(&mut input);
- let expected_token_stream = r#"let schema_fields = vec !
[apache_avro :: schema :: RecordField { name : "a3" . to_string () , doc : Some
("a doc" . into ()) , default : Some (serde_json :: from_str ("123") . expect
(format ! ("Invalid JSON: {:?}" , "123") . as_str ())) , aliases : Some (vec !
["a1" . into () , "a2" . into ()]) , schema : apache_avro :: schema :: Schema
:: Int , order : apache_avro :: schema :: RecordFieldOrder :: Ascending ,
position : 0usize , custom_att [...]
+ let expected_token_stream = r#"let mut schema_fields = Vec ::
with_capacity (1usize) ; schema_fields . push (:: apache_avro :: schema ::
RecordField { name : "a3" . to_string () , doc : Some ("a doc" . into ()) ,
default : Some (serde_json :: from_str ("123") . expect (format ! ("Invalid
JSON: {:?}" , "123") . as_str ())) , aliases : Some (vec ! ["a1" . into () ,
"a2" . into ()]) , schema : apache_avro :: schema :: Schema :: Int , order : ::
apache_avro :: schema :: Recor [...]
let schema_token_stream = schema_res.unwrap().to_string();
assert!(schema_token_stream.contains(expected_token_stream));
}
@@ -725,7 +751,7 @@ mod tests {
match syn::parse2::<DeriveInput>(test_struct) {
Ok(mut input) => {
let schema_res = derive_avro_schema(&mut input);
- let expected_token_stream = r#"let name = apache_avro ::
schema :: Name :: new ("A") . expect (& format ! ("Unable to parse schema name
{}" , "A") [..]) . fully_qualified_name (enclosing_namespace) ; let
enclosing_namespace = & name . namespace ; if named_schemas . contains_key (&
name) { apache_avro :: schema :: Schema :: Ref { name : name . clone () } }
else { named_schemas . insert (name . clone () , apache_avro :: schema ::
Schema :: Ref { name : name . clone () }) ; [...]
+ let expected_token_stream = r#"let name = apache_avro ::
schema :: Name :: new ("A") . expect (& format ! ("Unable to parse schema name
{}" , "A") [..]) . fully_qualified_name (enclosing_namespace) ; let
enclosing_namespace = & name . namespace ; if named_schemas . contains_key (&
name) { apache_avro :: schema :: Schema :: Ref { name : name . clone () } }
else { named_schemas . insert (name . clone () , apache_avro :: schema ::
Schema :: Ref { name : name . clone () }) ; [...]
let schema_token_stream = schema_res.unwrap().to_string();
assert!(schema_token_stream.contains(expected_token_stream));
}
@@ -769,7 +795,7 @@ mod tests {
match syn::parse2::<DeriveInput>(test_struct) {
Ok(mut input) => {
let schema_res = derive_avro_schema(&mut input);
- let expected_token_stream = r#"let name = apache_avro ::
schema :: Name :: new ("A") . expect (& format ! ("Unable to parse schema name
{}" , "A") [..]) . fully_qualified_name (enclosing_namespace) ; let
enclosing_namespace = & name . namespace ; if named_schemas . contains_key (&
name) { apache_avro :: schema :: Schema :: Ref { name : name . clone () } }
else { named_schemas . insert (name . clone () , apache_avro :: schema ::
Schema :: Ref { name : name . clone () }) ; [...]
+ let expected_token_stream = r#"let name = apache_avro ::
schema :: Name :: new ("A") . expect (& format ! ("Unable to parse schema name
{}" , "A") [..]) . fully_qualified_name (enclosing_namespace) ; let
enclosing_namespace = & name . namespace ; if named_schemas . contains_key (&
name) { apache_avro :: schema :: Schema :: Ref { name : name . clone () } }
else { named_schemas . insert (name . clone () , apache_avro :: schema ::
Schema :: Ref { name : name . clone () }) ; [...]
let schema_token_stream = schema_res.unwrap().to_string();
assert!(schema_token_stream.contains(expected_token_stream));
}
diff --git a/avro_derive/tests/derive.rs b/avro_derive/tests/derive.rs
index 8d92c57..6972e9a 100644
--- a/avro_derive/tests/derive.rs
+++ b/avro_derive/tests/derive.rs
@@ -1686,4 +1686,168 @@ mod test_derive {
panic!("Unexpected schema type for Foo")
}
}
+
+ #[test]
+ fn avro_rs_247_serde_flatten_support() {
+ #[derive(Debug, Serialize, Deserialize, AvroSchema, Clone, PartialEq)]
+ struct Nested {
+ a: bool,
+ }
+
+ #[derive(Debug, Serialize, Deserialize, AvroSchema, Clone, PartialEq)]
+ struct Foo {
+ #[serde(flatten)]
+ #[avro(flatten)]
+ nested: Nested,
+ b: i32,
+ }
+
+ let schema = r#"
+ {
+ "type":"record",
+ "name":"Foo",
+ "fields": [
+ {
+ "name":"a",
+ "type":"boolean"
+ },
+ {
+ "name":"b",
+ "type":"int"
+ }
+ ]
+ }
+ "#;
+
+ let schema = Schema::parse_str(schema).unwrap();
+ assert_eq!(schema, Foo::get_schema());
+
+ serde_assert(Foo {
+ nested: Nested { a: true },
+ b: 321,
+ });
+ }
+
+ #[test]
+ fn avro_rs_247_serde_nested_flatten_support() {
+ use apache_avro::AvroSchema;
+ use serde::{Deserialize, Serialize};
+
+ #[derive(AvroSchema, Debug, Clone, PartialEq, Serialize, Deserialize)]
+ pub struct NestedFoo {
+ one: u32,
+ }
+
+ #[derive(AvroSchema, Debug, Clone, PartialEq, Serialize, Deserialize)]
+ pub struct Foo {
+ #[serde(flatten)]
+ #[avro(flatten)]
+ nested_foo: NestedFoo,
+ }
+
+ #[derive(AvroSchema, Debug, Clone, PartialEq, Serialize, Deserialize)]
+ struct Bar {
+ foo: Foo,
+ two: u32,
+ }
+
+ let schema = r#"
+ {
+ "type":"record",
+ "name":"Bar",
+ "fields": [
+ {
+ "name":"foo",
+ "type": {
+ "type": "record",
+ "name": "Foo",
+ "fields": [
+ {
+ "name": "one",
+ "type": "long"
+ }
+ ]
+ }
+ },
+ {
+ "name":"two",
+ "type":"long"
+ }
+ ]
+ }
+ "#;
+
+ let schema = Schema::parse_str(schema).unwrap();
+ assert_eq!(schema, Bar::get_schema());
+
+ serde_assert(Bar {
+ foo: Foo {
+ nested_foo: NestedFoo { one: 42 },
+ },
+ two: 2,
+ });
+ }
+
+ #[test]
+ #[should_panic(expected = "Duplicate field names found")]
+ fn avro_rs_247_serde_flatten_support_duplicate_field_name() {
+ #[derive(Debug, Serialize, Deserialize, AvroSchema, Clone, PartialEq)]
+ struct Nested {
+ a: i32,
+ }
+
+ #[derive(Debug, Serialize, Deserialize, AvroSchema, Clone, PartialEq)]
+ struct Foo {
+ #[serde(flatten)]
+ #[avro(flatten)]
+ nested: Nested,
+ a: i32,
+ }
+
+ Foo::get_schema();
+ }
+
+ #[test]
+ fn avro_rs_247_serde_flatten_support_with_skip() {
+ #[derive(Debug, Serialize, Deserialize, AvroSchema, Clone, PartialEq)]
+ struct Nested {
+ a: bool,
+ #[serde(skip)]
+ #[avro(skip)]
+ c: f64,
+ }
+
+ #[derive(Debug, Serialize, Deserialize, AvroSchema, Clone, PartialEq)]
+ struct Foo {
+ #[serde(flatten)]
+ #[avro(flatten)]
+ nested: Nested,
+ b: i32,
+ }
+
+ let schema = r#"
+ {
+ "type":"record",
+ "name":"Foo",
+ "fields": [
+ {
+ "name":"a",
+ "type":"boolean"
+ },
+ {
+ "name":"b",
+ "type":"int"
+ }
+ ]
+ }
+ "#;
+
+ let schema = Schema::parse_str(schema).unwrap();
+ assert_eq!(schema, Foo::get_schema());
+
+ serde_assert(Foo {
+ nested: Nested { a: true, c: 0.0 },
+ b: 321,
+ });
+ }
}