This is an automated email from the ASF dual-hosted git repository. kriskras99 pushed a commit to branch feat/avro_schema_default in repository https://gitbox.apache.org/repos/asf/avro-rs.git
commit a8b671b596e7aa957e824514b2223b306dbc622e Author: Kriskras99 <[email protected]> AuthorDate: Sun Feb 8 22:49:11 2026 +0100 wip: Automatic defaults for derive --- avro/src/error.rs | 111 +++++++++++- avro/src/schema/mod.rs | 128 ++++++++++---- avro/src/schema/record/field.rs | 343 ++++++++++++++++++++++++++++++++----- avro/src/serde/derive.rs | 68 +++++++- avro_derive/src/attributes/avro.rs | 6 +- avro_derive/src/attributes/mod.rs | 54 +++++- avro_derive/src/lib.rs | 36 +++- avro_derive/tests/derive.rs | 2 + 8 files changed, 658 insertions(+), 90 deletions(-) diff --git a/avro/src/error.rs b/avro/src/error.rs index ea81125..74f210a 100644 --- a/avro/src/error.rs +++ b/avro/src/error.rs @@ -48,6 +48,10 @@ impl Error { pub fn into_details(self) -> Details { *self.details } + + pub fn into_boxed_details(self) -> Box<Details> { + self.details + } } impl From<Details> for Error { @@ -306,11 +310,110 @@ pub enum Details { #[error("Unions cannot contain more than one named schema with the same name: {0}")] GetUnionDuplicateNamedSchemas(String), - #[error("One union type {0:?} must match the `default`'s value type {1:?}")] - GetDefaultUnion(SchemaKind, ValueKind), + #[error( + "`default` of field {field_name} in {record_name} is a {value} and does not match any of the union schemas ({schemas:?})" + )] + GetDefaultUnion { + field_name: String, + record_name: String, + schemas: Vec<SchemaKind>, + value: &'static str, + }, + + #[error( + "`default` of field {field_name} in {record_name} must be {expected_ty} but got {actual_ty}" + )] + GetDefaultRecordField { + field_name: String, + record_name: String, + expected_ty: &'static str, + actual_ty: String, + }, + + #[error( + "`default` of field {field_name} in {record_name} has length {default_len} but Fixed length is {fixed_len}" + )] + GetDefaultFixedWrongLength { + field_name: String, + record_name: String, + fixed_len: usize, + default_len: usize, + }, + + #[error( + "`default` of field {field_name} in {record_name} has 
the character '{character}' whose code point is larger than U+00FF" + )] + GetDefaultTooLargeChar { + field_name: String, + record_name: String, + character: char, + }, + + #[error( + "`default` of field {field_name} in {record_name} has the wrong item at position {position}:\n{inner}" + )] + GetDefaultArray { + field_name: String, + record_name: String, + position: usize, + inner: Box<Details>, + }, + + #[error( + "`default` of field {field_name} in {record_name} is missing the field {missing_field}" + )] + GetDefaultRecordMissingField { + field_name: String, + record_name: String, + missing_field: String, + }, + + #[error( + "`default` of field {field_name} in {record_name} has the wrong field type for field {offending_field}:\n{inner}" + )] + GetDefaultRecordWrongField { + field_name: String, + record_name: String, + offending_field: String, + inner: Box<Details>, + }, + + #[error( + "`default` of field {field_name} in {record_name} has an unknown field {offending_field}" + )] + GetDefaultRecordUnknownField { + field_name: String, + record_name: String, + offending_field: String, + }, + + #[error( + "`default` of field {field_name} in {record_name} has the wrong value for key \"{key}\":\n{inner}" + )] + GetDefaultMap { + field_name: String, + record_name: String, + key: String, + inner: Box<Details>, + }, + + #[error( + "`default` of field {field_name} in {record_name} has an unknown enum variant \"{variant}\"" + )] + GetDefaultUnknownEnumVariant { + field_name: String, + record_name: String, + variant: String, + }, - #[error("`default`'s value type of field {0:?} in {1:?} must be {2:?}")] - GetDefaultRecordField(String, String, String), + #[error( + "`default` of field {field_name} in {record_name} has an unexpected value \"{unexpected}\"" + )] + GetDefaultFloatDoubleString { + field_name: String, + record_name: String, + unexpected: String, + }, #[error("JSON value {0} claims to be u64 but cannot be converted")] GetU64FromJson(serde_json::Number), diff --git
a/avro/src/schema/mod.rs b/avro/src/schema/mod.rs index e2d67a3..4d8efc8 100644 --- a/avro/src/schema/mod.rs +++ b/avro/src/schema/mod.rs @@ -3767,11 +3767,12 @@ mod tests { ] } "#; - let expected = Details::GetDefaultRecordField( - "f1".to_string(), - "ns.record1".to_string(), - r#""int""#.to_string(), - ) + let expected = Details::GetDefaultRecordField { + field_name: "f1".to_string(), + record_name: "ns.record1".to_string(), + actual_ty: "string".to_string(), + expected_ty: "number", + } .to_string(); let result = Schema::parse_str(schema_str); assert!(result.is_err()); @@ -3809,12 +3810,12 @@ mod tests { ] } "#; - let expected = Details::GetDefaultRecordField( - "f1".to_string(), - "ns.record1".to_string(), - r#"{"name":"ns.record2","type":"record","fields":[{"name":"f1_1","type":"int"}]}"# - .to_string(), - ) + let expected = Details::GetDefaultRecordField { + field_name: "f1".to_string(), + record_name: "ns.record1".to_string(), + actual_ty: "string".to_string(), + expected_ty: "object", + } .to_string(); let result = Schema::parse_str(schema_str); assert!(result.is_err()); @@ -3847,11 +3848,11 @@ mod tests { ] } "#; - let expected = Details::GetDefaultRecordField( - "f1".to_string(), - "ns.record1".to_string(), - r#"{"name":"ns.enum1","type":"enum","symbols":["a","b","c"]}"#.to_string(), - ) + let expected = Details::GetDefaultUnknownEnumVariant { + field_name: "f1".to_string(), + record_name: "ns.record1".to_string(), + variant: "invalid".to_string(), + } .to_string(); let result = Schema::parse_str(schema_str); assert!(result.is_err()); @@ -3884,11 +3885,12 @@ mod tests { ] } "#; - let expected = Details::GetDefaultRecordField( - "f1".to_string(), - "ns.record1".to_string(), - r#"{"name":"ns.fixed1","type":"fixed","size":3}"#.to_string(), - ) + let expected = Details::GetDefaultRecordField { + field_name: "f1".to_string(), + record_name: "ns.record1".to_string(), + actual_ty: "number".to_string(), + expected_ty: "string", + } .to_string(); let result = 
Schema::parse_str(schema_str); assert!(result.is_err()); @@ -3918,11 +3920,12 @@ mod tests { ] } "#; - let expected = Details::GetDefaultRecordField( - "f1".to_string(), - "ns.record1".to_string(), - r#"{"type":"array","items":"int"}"#.to_string(), - ) + let expected = Details::GetDefaultRecordField { + field_name: "f1".to_string(), + record_name: "ns.record1".to_string(), + actual_ty: "string".to_string(), + expected_ty: "array", + } .to_string(); let result = Schema::parse_str(schema_str); assert!(result.is_err()); @@ -3952,11 +3955,12 @@ mod tests { ] } "#; - let expected = Details::GetDefaultRecordField( - "f1".to_string(), - "ns.record1".to_string(), - r#"{"type":"map","values":"string"}"#.to_string(), - ) + let expected = Details::GetDefaultRecordField { + field_name: "f1".to_string(), + record_name: "ns.record1".to_string(), + actual_ty: "string".to_string(), + expected_ty: "object", + } .to_string(); let result = Schema::parse_str(schema_str); assert!(result.is_err()); @@ -3997,11 +4001,17 @@ mod tests { ] } "#; - let expected = Details::GetDefaultRecordField( - "f2".to_string(), - "ns.record1".to_string(), - r#""ns.record2""#.to_string(), - ) + let expected = Details::GetDefaultRecordWrongField { + field_name: "f1".to_string(), + record_name: "ns.record1".to_string(), + offending_field: "f1_1".to_string(), + inner: Box::new(Details::GetDefaultRecordField { + field_name: "f1".to_string(), + record_name: "ns.record1".to_string(), + actual_ty: "boolean".to_string(), + expected_ty: "number", + }), + } .to_string(); let result = Schema::parse_str(schema_str); assert!(result.is_err()); @@ -5128,4 +5138,50 @@ mod tests { ); Ok(()) } + + #[test] + fn avro_rs_xxx_schema_defaults() -> TestResult { + let _schema = Schema::parse_str( + r#"{ + "type": "record", + "name": "defaults", + "fields": [ + {"name": "null", "type": "null", "default": null}, + {"name": "boolean", "type": "boolean", "default": true}, + {"name": "int", "type": "int", "default": 1 }, + {"name": 
"long", "type": "long", "default": 1 }, + {"name": "float", "type": "float", "default": 1.1 }, + {"name": "double", "type": "double", "default": 1.1 }, + {"name": "bytes", "type": "bytes", "default": "\u00FF" }, + {"name": "string", "type": "string", "default": "foo" }, + {"name": "record", "type": { + "type": "record", + "name": "innner", + "fields": [{"name": "a", "type": "int"}] + }, "default": {"a": 1}}, + {"name": "enum", "type": { + "type": "enum", + "name": "bar", + "symbols": ["FOO"] + }, "default": "FOO"}, + {"name": "array", "type": { + "type": "array", + "name": "items", + "items": "int" + }, "default": [1] }, + {"name": "map", "type": { + "type": "map", + "name": "values", + "values": "int" + }, "default": {"a": 1}}, + {"name": "fixed", "type": { + "type": "fixed", + "name": "one", + "size": 1 + }, "default": "\u00FF"} + ] + }"#, + )?; + Ok(()) + } } diff --git a/avro/src/schema/record/field.rs b/avro/src/schema/record/field.rs index 6e70cba..08c2774 100644 --- a/avro/src/schema/record/field.rs +++ b/avro/src/schema/record/field.rs @@ -18,9 +18,9 @@ use crate::AvroResult; use crate::error::Details; use crate::schema::{ - Documentation, Name, Names, Parser, RecordSchemaParseLocation, Schema, SchemaKind, + DecimalSchema, Documentation, InnerDecimalSchema, Name, Names, Parser, + RecordSchemaParseLocation, Schema, UuidSchema, }; -use crate::types; use crate::util::MapHelper; use crate::validator::validate_record_field_name; use serde::ser::SerializeMap; @@ -94,7 +94,7 @@ impl RecordField { &name, &enclosing_record.fullname(None), parser.get_parsed_schemas(), - &default, + default.as_ref(), )?; let aliases = field.get("aliases").and_then(|aliases| { @@ -130,50 +130,313 @@ impl RecordField { field_name: &str, record_name: &str, names: &Names, - default: &Option<Value>, + default: Option<&Value>, ) -> AvroResult<()> { - if let Some(value) = default { - let avro_value = types::Value::from(value.clone()); - match field_schema { - Schema::Union(union_schema) => { 
- let schemas = &union_schema.schemas; - let resolved = schemas.iter().any(|schema| { - avro_value - .to_owned() - .resolve_internal(schema, names, &schema.namespace(), &None) - .is_ok() - }); + fn expected_type(schema: &Schema) -> &'static str { + match schema { + Schema::Null => "null", + Schema::Boolean => "boolean", + Schema::Int + | Schema::Long + | Schema::Date + | Schema::TimeMillis + | Schema::TimeMicros + | Schema::TimestampMillis + | Schema::TimestampMicros + | Schema::TimestampNanos + | Schema::LocalTimestampMillis + | Schema::LocalTimestampMicros + | Schema::LocalTimestampNanos => "number", + Schema::Float | Schema::Double => "number | string", + Schema::Bytes + | Schema::String + | Schema::Enum(_) + | Schema::Fixed(_) + | Schema::Decimal(_) + | Schema::BigDecimal + | Schema::Uuid(_) + | Schema::Duration(_) => "string", + Schema::Array(_) => "array", + Schema::Map(_) | Schema::Record(_) => "object", + Schema::Union(_) => panic!("this function should not be called for unions"), + Schema::Ref { ..
} => panic!("this function should not be called for references"), + } + } + fn value_kind(value: &Value) -> &'static str { + match value { + Value::Null => "null", + Value::Bool(_) => "boolean", + Value::Number(_) => "number", + Value::String(_) => "string", + Value::Array(_) => "array", + Value::Object(_) => "object", + } + } - if !resolved { - let schema: Option<&Schema> = schemas.first(); - return match schema { - Some(first_schema) => Err(Details::GetDefaultUnion( - SchemaKind::from(first_schema), - types::ValueKind::from(avro_value), - ) - .into()), - None => Err(Details::EmptyUnion.into()), - }; + if let Some(value) = default { + match (value, field_schema) { + (_, Schema::Union(union_schema)) => { + for schema in union_schema.variants() { + if Self::resolve_default_value( + schema, + field_name, + record_name, + names, + default, + ) + .is_ok() + { + return Ok(()); + } } + Err(Details::GetDefaultUnion { + field_name: field_name.to_string(), + record_name: record_name.to_string(), + schemas: union_schema + .variants() + .iter() + .map(Into::into) + .collect(), + value: value_kind(value), + } + .into()) } - _ => { - let resolved = avro_value - .resolve_internal(field_schema, names, &field_schema.namespace(), &None) - .is_ok(); - - if !resolved { - return Err(Details::GetDefaultRecordField( - field_name.to_string(), - record_name.to_string(), - field_schema.canonical_form(), - ) - .into()); + (_, Schema::Ref { name }) => { + let schema = names + .get(name) + .ok_or_else(|| Details::SchemaResolutionError(name.clone()))?; + Self::resolve_default_value(schema, field_name, record_name, names, default) + } + (Value::Null, Schema::Null) => Ok(()), + (Value::Null, _) => Err(Details::GetDefaultRecordField { + field_name: field_name.to_string(), + record_name: record_name.to_string(), + expected_ty: expected_type(field_schema), + actual_ty: "null".to_string(), + } + .into()), + (Value::Bool(_), Schema::Boolean) => Ok(()), + (Value::Bool(_), _) =>
Err(Details::GetDefaultRecordField { + field_name: field_name.to_string(), + record_name: record_name.to_string(), + expected_ty: expected_type(field_schema), + actual_ty: "boolean".to_string(), + } + .into()), + ( + Value::Number(_), + Schema::Int + | Schema::Long + | Schema::Float + | Schema::Double + | Schema::Date + | Schema::TimeMillis + | Schema::TimeMicros + | Schema::TimestampMillis + | Schema::TimestampMicros + | Schema::TimestampNanos + | Schema::LocalTimestampMillis + | Schema::LocalTimestampMicros + | Schema::LocalTimestampNanos, + ) => Ok(()), + (Value::Number(_), _) => Err(Details::GetDefaultRecordField { + field_name: field_name.to_string(), + record_name: record_name.to_string(), + expected_ty: expected_type(field_schema), + actual_ty: "number".to_string(), + } + .into()), + ( + Value::String(string), + Schema::Bytes + | Schema::String + | Schema::Decimal(DecimalSchema { + inner: InnerDecimalSchema::Bytes, + .. + }) + | Schema::BigDecimal + | Schema::Uuid(UuidSchema::String | UuidSchema::Bytes), + ) => { + // Per the Avro specification, bytes defaults are JSON strings whose code points + // U+0000..=U+00FF map to one byte each; strings may hold any Unicode character. + if matches!(field_schema, Schema::String | Schema::Uuid(UuidSchema::String)) { + return Ok(()); + } + if let Some(character) = string.chars().find(|c| u32::from(*c) > 0xFF) { + Err(Details::GetDefaultTooLargeChar { + field_name: field_name.to_string(), + record_name: record_name.to_string(), + character, + })? + } else { + Ok(()) } } - }; + ( + Value::String(string), + Schema::Fixed(fixed) + | Schema::Decimal(DecimalSchema { + inner: InnerDecimalSchema::Fixed(fixed), + .. + }) + | Schema::Uuid(UuidSchema::Fixed(fixed)) + | Schema::Duration(fixed), + ) => { + // Per the Avro specification, fixed defaults are JSON strings in which each + // code point U+0000..=U+00FF maps to one byte, so the default must contain
+ // exactly `fixed.size` characters, all of them at most U+00FF. For example, + // an all-zero default for a fixed of size 2 is written as "\u0000\u0000" in the + // schema JSON. + let mut count = 0; + for character in string.chars() { + count += 1; + if u32::from(character) > 0xFF { + Err(Details::GetDefaultTooLargeChar { + field_name: field_name.to_string(), + record_name: record_name.to_string(), + character, + })? + } + } + if count == fixed.size { + Ok(()) + } else { + Err(Details::GetDefaultFixedWrongLength { + field_name: field_name.to_string(), + record_name: record_name.to_string(), + fixed_len: fixed.size, + default_len: count, + } + .into()) + } + } + (Value::String(string), Schema::Enum(enum_schema)) => { + if !enum_schema.symbols.contains(string) { + Err(Details::GetDefaultUnknownEnumVariant { + field_name: field_name.to_string(), + record_name: record_name.to_string(), + variant: string.clone(), + } + .into()) + } else { + Ok(()) + } + } + (Value::String(string), Schema::Float | Schema::Double) => match string.as_str() { + "Infinity" | "-Infinity" | "INF" | "-INF" | "NaN" | "NAN" => { + Ok(()) + } + _ => Err(Details::GetDefaultFloatDoubleString { + field_name: field_name.to_string(), + record_name: record_name.to_string(), + unexpected: string.clone(), + } + .into()), + }, + (Value::String(_), _) => Err(Details::GetDefaultRecordField { + field_name: field_name.to_string(), + record_name: record_name.to_string(), + expected_ty: expected_type(field_schema), + actual_ty: "string".to_string(), + } + .into()), + (Value::Array(array_default), Schema::Array(array_schema)) => { + for (position, value) in array_default.iter().enumerate() { + if let Err(err) = Self::resolve_default_value( + &array_schema.items, + field_name, + record_name, + names, + Some(value), + ) { + return Err(Details::GetDefaultArray { + field_name: field_name.to_string(), + record_name: record_name.to_string(), + position, +
inner: err.into_boxed_details(), + } + .into()); + } + } + Ok(()) + } + (Value::Array(_), _) => Err(Details::GetDefaultRecordField { + field_name: field_name.to_string(), + record_name: record_name.to_string(), + expected_ty: expected_type(field_schema), + actual_ty: "array".to_string(), + } + .into()), + (Value::Object(object), Schema::Record(record_schema)) => { + for field in &record_schema.fields { + if let Some(value) = object.get(&field.name) { + if let Err(err) = Self::resolve_default_value( + &field.schema, + field_name, + record_name, + names, + Some(value), + ) { + return Err(Details::GetDefaultRecordWrongField { + field_name: field_name.to_string(), + record_name: record_name.to_string(), + offending_field: field.name.clone(), + inner: err.into_boxed_details(), + } + .into()); + } + } else { + return Err(Details::GetDefaultRecordMissingField { + field_name: field_name.to_string(), + record_name: record_name.to_string(), + missing_field: field.name.clone(), + } + .into()); + } + } + for key in object.keys() { + if !record_schema.lookup.contains_key(key) { + return Err(Details::GetDefaultRecordUnknownField { + field_name: field_name.to_string(), + record_name: record_name.to_string(), + offending_field: key.clone(), + } + .into()); + } + } + Ok(()) + } + (Value::Object(object), Schema::Map(map_schema)) => { + for (key, value) in object.iter() { + if let Err(err) = Self::resolve_default_value( + &map_schema.types, + field_name, + record_name, + names, + Some(value), + ) { + return Err(Details::GetDefaultMap { + field_name: field_name.to_string(), + record_name: record_name.to_string(), + key: key.clone(), + inner: err.into_boxed_details(), + } + .into()); + } + } + Ok(()) + } + (Value::Object(_), _) => Err(Details::GetDefaultRecordField { + field_name: field_name.to_string(), + record_name: record_name.to_string(), + expected_ty: expected_type(field_schema), + actual_ty: "object".to_string(), + } + .into()), + } + } else { + Ok(()) } - - Ok(()) } fn 
get_field_custom_attributes( diff --git a/avro/src/serde/derive.rs b/avro/src/serde/derive.rs index 451f0b1..9fcb4c4 100644 --- a/avro/src/serde/derive.rs +++ b/avro/src/serde/derive.rs @@ -328,6 +328,16 @@ pub trait AvroSchemaComponent { Self::get_schema_in_ctxt, ) } + + /// The default value of this type when used for a record field. + /// + /// `None` means no default value, which is also the default implementation. + /// + /// Implementations of this trait provided by this crate will use the [`Default::default`] + /// implementation of the type. + fn field_default() -> Option<serde_json::Value> { + None + } } /// Get the record fields from `schema_fn` without polluting `named_schemas` or causing duplicate names @@ -474,6 +484,10 @@ where macro_rules! impl_schema ( ($type:ty, $variant_constructor:expr) => ( + impl_schema!($type, $variant_constructor, <$type as Default>::default()); + ); + + ($type:ty, $variant_constructor:expr, $default_constructor:expr) => ( impl AvroSchemaComponent for $type { fn get_schema_in_ctxt(_: &mut Names, _: &Namespace) -> Schema { $variant_constructor @@ -482,6 +496,10 @@ macro_rules! impl_schema ( fn get_record_fields_in_ctxt(_: usize, _: &mut Names, _: &Namespace) -> Option<Vec<RecordField>> { None } + + fn field_default() -> Option<serde_json::Value> { + Some(serde_json::Value::from($default_constructor)) + } } ); ); @@ -497,8 +515,8 @@ impl_schema!(u32, Schema::Long); impl_schema!(f32, Schema::Float); impl_schema!(f64, Schema::Double); impl_schema!(String, Schema::String); -impl_schema!(str, Schema::String); -impl_schema!(char, Schema::String); +impl_schema!(str, Schema::String, String::default()); +impl_schema!(char, Schema::String, String::from(char::default())); macro_rules! impl_passthrough_schema ( ($type:ty where T: AvroSchemaComponent + ?Sized $(+ $bound:tt)*) => ( @@ -510,6 +528,10 @@ macro_rules! 
impl_passthrough_schema ( fn get_record_fields_in_ctxt(first_field_position: usize, named_schemas: &mut Names, enclosing_namespace: &Namespace) -> Option<Vec<RecordField>> { T::get_record_fields_in_ctxt(first_field_position, named_schemas, enclosing_namespace) } + + fn field_default() -> Option<serde_json::Value> { + T::field_default() + } } ); ); @@ -530,6 +552,10 @@ macro_rules! impl_array_schema ( fn get_record_fields_in_ctxt(_: usize, _: &mut Names, _: &Namespace) -> Option<Vec<RecordField>> { None } + + fn field_default() -> Option<serde_json::Value> { + Some(serde_json::Value::Array(Vec::new())) + } } ); ); @@ -554,6 +580,10 @@ where ) -> Option<Vec<RecordField>> { None } + + fn field_default() -> Option<serde_json::Value> { + Some(serde_json::Value::Array(Vec::new())) + } } impl<T> AvroSchemaComponent for HashMap<String, T> @@ -571,6 +601,10 @@ where ) -> Option<Vec<RecordField>> { None } + + fn field_default() -> Option<serde_json::Value> { + Some(serde_json::Value::Object(serde_json::Map::new())) + } } impl<T> AvroSchemaComponent for Option<T> @@ -595,6 +629,10 @@ where ) -> Option<Vec<RecordField>> { None } + + fn field_default() -> Option<serde_json::Value> { + Some(serde_json::Value::Null) + } } impl AvroSchemaComponent for core::time::Duration { @@ -629,6 +667,10 @@ impl AvroSchemaComponent for core::time::Duration { ) -> Option<Vec<RecordField>> { None } + + fn field_default() -> Option<serde_json::Value> { + Some(serde_json::Value::String("\0\0\0\0\0\0\0\0\0\0\0\0".into())) + } } impl AvroSchemaComponent for uuid::Uuid { @@ -663,6 +705,12 @@ impl AvroSchemaComponent for uuid::Uuid { ) -> Option<Vec<RecordField>> { None } + + fn field_default() -> Option<serde_json::Value> { + Some(serde_json::Value::String( + "\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0".into(), + )) + } } impl AvroSchemaComponent for u64 { @@ -695,6 +743,10 @@ impl AvroSchemaComponent for u64 { ) -> Option<Vec<RecordField>> { None } + + fn field_default() -> Option<serde_json::Value> { + 
Some(serde_json::Value::String("\0\0\0\0\0\0\0\0".into())) + } } impl AvroSchemaComponent for u128 { @@ -727,6 +779,12 @@ impl AvroSchemaComponent for u128 { ) -> Option<Vec<RecordField>> { None } + + fn field_default() -> Option<serde_json::Value> { + Some(serde_json::Value::String( + "\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0".into(), + )) + } } impl AvroSchemaComponent for i128 { @@ -759,6 +817,12 @@ impl AvroSchemaComponent for i128 { ) -> Option<Vec<RecordField>> { None } + + fn field_default() -> Option<serde_json::Value> { + Some(serde_json::Value::String( + "\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0".into(), + )) + } } #[cfg(test)] diff --git a/avro_derive/src/attributes/avro.rs b/avro_derive/src/attributes/avro.rs index ea171b5..fdf755e 100644 --- a/avro_derive/src/attributes/avro.rs +++ b/avro_derive/src/attributes/avro.rs @@ -21,6 +21,7 @@ //! Although a user will mostly use the Serde attributes, there are some Avro specific attributes //! a user can use. These add extra metadata to the generated schema. +use crate::attributes::FieldDefault; use crate::case::RenameRule; use darling::FromMeta; use proc_macro2::Span; @@ -53,6 +54,9 @@ pub struct ContainerAttributes { /// [`serde::ContainerAttributes::rename_all`]: crate::attributes::serde::ContainerAttributes::rename_all #[darling(default)] pub rename_all: RenameRule, + /// Set the default value if this schema is used as a field + #[darling(default)] + pub default: Option<String>, } impl ContainerAttributes { @@ -125,7 +129,7 @@ pub struct FieldAttributes { /// /// This is also used as the default when `skip_serializing{_if}` is used. #[darling(default)] - pub default: Option<String>, + pub default: FieldDefault, /// Deprecated. Use [`serde::FieldAttributes::alias`] instead. /// /// Adds the `aliases` field to the schema. 
diff --git a/avro_derive/src/attributes/mod.rs b/avro_derive/src/attributes/mod.rs index cc259f1..d05c8c3 100644 --- a/avro_derive/src/attributes/mod.rs +++ b/avro_derive/src/attributes/mod.rs @@ -17,7 +17,8 @@ use crate::case::RenameRule; use darling::{FromAttributes, FromMeta}; -use proc_macro2::Span; +use proc_macro2::{Span, TokenStream}; +use quote::quote; use syn::{AttrStyle, Attribute, Expr, Ident, Path, spanned::Spanned}; mod avro; @@ -30,6 +31,7 @@ pub struct NamedTypeOptions { pub aliases: Vec<String>, pub rename_all: RenameRule, pub transparent: bool, + pub default: TokenStream, } impl NamedTypeOptions { @@ -116,12 +118,29 @@ impl NamedTypeOptions { let doc = avro.doc.or_else(|| extract_rustdoc(attributes)); + let default = match avro.default { + None => quote! { None }, + Some(default_value) => { + let _: serde_json::Value = + serde_json::from_str(&default_value[..]).map_err(|e| { + vec![syn::Error::new( + ident.span(), + format!("Invalid avro default json: \n{e}"), + )] + })?; + quote! { + Some(serde_json::from_str(#default_value).expect(format!("Invalid JSON: {:?}", #default_value).as_str())) + } + } + }; + Ok(Self { name: full_schema_name, doc, aliases: avro.alias, rename_all: serde.rename_all.serialize, transparent: serde.transparent, + default, }) } } @@ -210,11 +229,38 @@ impl With { } } } +/// How to get the default value for a value. +#[derive(Debug, PartialEq, Default)] +pub enum FieldDefault { + /// Use `<T as AvroSchemaComponent>::field_default`. + #[default] + Trait, + /// Don't set a default. + Disabled, + /// Use this JSON value. 
+ Value(String), +} + +impl FromMeta for FieldDefault { + fn from_string(value: &str) -> darling::Result<Self> { + Ok(Self::Value(value.to_string())) + } + + fn from_bool(value: bool) -> darling::Result<Self> { + if value { + Err(darling::Error::custom( + "Expected `false` or a JSON string, got `true`", + )) + } else { + Ok(Self::Disabled) + } + } +} #[derive(Default)] pub struct FieldOptions { pub doc: Option<String>, - pub default: Option<String>, + pub default: FieldDefault, pub alias: Vec<String>, pub rename: Option<String>, pub skip: bool, @@ -274,11 +320,11 @@ impl FieldOptions { } if ((serde.skip_serializing && !serde.skip_deserializing) || serde.skip_serializing_if.is_some()) - && avro.default.is_none() + && avro.default == FieldDefault::Disabled { errors.push(syn::Error::new( span, - "`#[serde(skip_serializing)]` and `#[serde(skip_serializing_if)]` require `#[avro(default = \"..\")]`" + "`#[serde(skip_serializing)]` and `#[serde(skip_serializing_if)]` are incompatible with `#[avro(default = false)]`" )); } let with = match With::from_avro_and_serde(&avro.with, &serde.with, span) { diff --git a/avro_derive/src/lib.rs b/avro_derive/src/lib.rs index 64e8fc9..c9c3e37 100644 --- a/avro_derive/src/lib.rs +++ b/avro_derive/src/lib.rs @@ -40,7 +40,7 @@ use syn::{ }; use crate::{ - attributes::{FieldOptions, NamedTypeOptions, VariantOptions, With}, + attributes::{FieldDefault, FieldOptions, NamedTypeOptions, VariantOptions, With}, case::RenameRule, }; @@ -75,6 +75,7 @@ fn derive_avro_schema(input: DeriveInput) -> Result<TokenStream, Vec<syn::Error> &input.generics, get_schema_impl, get_record_fields_impl, + named_type_options.default, )) } syn::Data::Enum(data_enum) => { @@ -93,6 +94,7 @@ fn derive_avro_schema(input: DeriveInput) -> Result<TokenStream, Vec<syn::Error> &input.generics, inner, quote! 
{ None }, + named_type_options.default, )) } syn::Data::Union(_) => Err(vec![syn::Error::new( @@ -108,6 +110,7 @@ fn create_trait_definition( generics: &Generics, get_schema_impl: TokenStream, get_record_fields_impl: TokenStream, + field_default_impl: TokenStream, ) -> TokenStream { let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); quote! { @@ -120,6 +123,10 @@ fn create_trait_definition( fn get_record_fields_in_ctxt(mut field_position: usize, named_schemas: &mut ::apache_avro::schema::Names, enclosing_namespace: &::std::option::Option<::std::string::String>) -> ::std::option::Option<::std::vec::Vec<::apache_avro::schema::RecordField>> { #get_record_fields_impl } + + fn field_default() -> ::std::option::Option<::serde_json::Value> { + #field_default_impl + } } } } @@ -191,7 +198,9 @@ fn get_struct_schema_def( continue; } let default_value = match field_attrs.default { - Some(default_value) => { + FieldDefault::Disabled => quote! { None }, + FieldDefault::Trait => type_to_field_default_expr(&field.ty)?, + FieldDefault::Value(default_value) => { let _: serde_json::Value = serde_json::from_str(&default_value[..]) .map_err(|e| { vec![syn::Error::new( @@ -203,7 +212,6 @@ fn get_struct_schema_def( Some(serde_json::from_str(#default_value).expect(format!("Invalid JSON: {:?}", #default_value).as_str())) } } - None => quote! { None }, }; let aliases = aliases(&field_attrs.alias); let schema_expr = get_field_schema_expr(&field, field_attrs.with)?; @@ -466,6 +474,28 @@ fn type_to_get_record_fields_expr(ty: &Type) -> Result<TokenStream, Vec<syn::Err } } +fn type_to_field_default_expr(ty: &Type) -> Result<TokenStream, Vec<syn::Error>> { + match ty { + Type::Array(_) | Type::Slice(_) | Type::Path(_) | Type::Reference(_) => { + Ok(quote! 
{<#ty as apache_avro::AvroSchemaComponent>::field_default()}) + } + Type::Ptr(_) => Err(vec![syn::Error::new_spanned( + ty, + "AvroSchema: derive does not support raw pointers", + )]), + Type::Tuple(_) => Err(vec![syn::Error::new_spanned( + ty, + "AvroSchema: derive does not support tuples", + )]), + _ => Err(vec![syn::Error::new_spanned( + ty, + format!( + "AvroSchema: Unexpected type encountered! Please open an issue if this kind of type should be supported: {ty:?}" + ), + )]), + } +} + fn default_enum_variant( data_enum: &syn::DataEnum, error_span: Span, diff --git a/avro_derive/tests/derive.rs b/avro_derive/tests/derive.rs index d6b1c4e..c425808 100644 --- a/avro_derive/tests/derive.rs +++ b/avro_derive/tests/derive.rs @@ -1365,6 +1365,7 @@ fn test_basic_struct_with_defaults() { #[avro(default = "true")] condition: bool, // no default value for 'c' + #[avro(default = false)] c: f64, #[avro(default = r#"{"a": 1, "b": 2}"#)] map: HashMap<String, i32>, @@ -1935,6 +1936,7 @@ fn avro_rs_397_uuid() { "fields": [ { "name":"baz", + "default": "\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000", "type":{ "type":"fixed", "logicalType":"uuid",
