diff --git a/serde_derive/src/de.rs b/serde_derive/src/de.rs index 4967e35d1..b306df112 100644 --- a/serde_derive/src/de.rs +++ b/serde_derive/src/de.rs @@ -1042,7 +1042,11 @@ fn deserialize_struct( let fields_stmt = if has_flatten { None } else { - let field_names = deserialized_fields.iter().flat_map(|field| field.aliases); + let implied_fields = cattrs.implied().iter().map(|(key, _)| key); + let field_names = deserialized_fields + .iter() + .flat_map(|field| field.aliases) + .chain(implied_fields); Some(quote! { #[doc(hidden)] @@ -1155,7 +1159,11 @@ fn deserialize_struct_in_place( }; let visit_seq = Stmts(deserialize_seq_in_place(params, fields, cattrs, expecting)); let visit_map = Stmts(deserialize_map_in_place(params, fields, cattrs)); - let field_names = deserialized_fields.iter().flat_map(|field| field.aliases); + let implied_fields = cattrs.implied().iter().map(|(key, _)| key); + let field_names = deserialized_fields + .iter() + .flat_map(|field| field.aliases) + .chain(implied_fields); let type_name = cattrs.name().deserialize_name(); let in_place_impl_generics = de_impl_generics.in_place(); @@ -1276,6 +1284,7 @@ fn prepare_enum_variant_enum(variants: &[Variant]) -> (TokenStream, Stmts) { true, None, fallthrough, + None, )); (variants_stmt, variant_visitor) @@ -2022,8 +2031,29 @@ fn deserialize_generated_identifier( is_variant: bool, ignore_variant: Option, fallthrough: Option, + implied_fields: Option<&[(Name, Name)]>, ) -> Fragment { let this_value = quote!(__Field); + + let implied_fields: Vec<_> = implied_fields + .unwrap_or_default() + .iter() + .map(|(key, value)| { + let ident = Ident::new(&format!("__field_{}", key.value), key.span); + (key, value, ident) + }) + .collect(); + + let implied_idents: Vec<_> = implied_fields.iter().map(|(_, _, ident)| ident).collect(); + let implied_field_mapping: Vec<_> = implied_fields + .iter() + .map(|(key, _, ident)| { + quote! { + #key => _serde::__private::Ok(#this_value::#ident), + } + }) + .collect(); + let field_idents: &Vec<_> = &deserialized_fields .iter() .map(|field| &field.ident) @@ -2037,6 +2067,7 @@ fn deserialize_generated_identifier( None, !is_variant && has_flatten, None, + &implied_field_mapping, )); let lifetime = if !is_variant && has_flatten { @@ -2049,6 +2080,7 @@ fn deserialize_generated_identifier( #[allow(non_camel_case_types)] #[doc(hidden)] enum __Field #lifetime { + #(#implied_idents,)* #(#field_idents,)* #ignore_variant } @@ -2101,6 +2133,7 @@ fn deserialize_field_identifier( false, ignore_variant, fallthrough, + Some(cattrs.implied()), )) } @@ -2162,7 +2195,11 @@ fn deserialize_custom_identifier( }) .collect(); - let names = idents_aliases.iter().flat_map(|variant| variant.aliases); + let implied_fields = cattrs.implied().iter().map(|(key, _)| key); + let names = idents_aliases + .iter() + .flat_map(|variant| variant.aliases) + .chain(implied_fields); let names_const = if fallthrough.is_some() { None @@ -2191,6 +2228,7 @@ fn deserialize_custom_identifier( fallthrough_borrowed, false, cattrs.expecting(), + &[], )); quote_block! { @@ -2225,6 +2263,7 @@ fn deserialize_identifier( fallthrough_borrowed: Option, collect_other_fields: bool, expecting: Option<&str>, + additional_mapping: &[TokenStream], ) -> Fragment { let str_mapping = deserialized_fields.iter().map(|field| { let ident = &field.ident; @@ -2484,6 +2523,7 @@ fn deserialize_identifier( { match __value { #(#str_mapping)* + #(#additional_mapping)* _ => { #value_as_str_content #fallthrough_arm @@ -2523,6 +2563,21 @@ fn deserialize_map( .map(|(i, field)| (field, field_i(i))) .collect(); + let implied_fields: Vec<_> = cattrs + .implied() + .iter() + .map(|(key, value)| { + let ident = Ident::new(&format!("__field_{}", key.value), key.span); + (key, value, ident) + }) + .collect(); + + let implied_let_values = implied_fields.iter().map(|(_, _, ident)| { + quote! { + let mut #ident: _serde::__private::Option<::std::string::String> = _serde::__private::None; + } + }); + // Declare each field that will be deserialized. let let_values = fields_names .iter() @@ -2532,7 +2587,8 @@ fn deserialize_map( quote! { let mut #name: _serde::__private::Option<#field_ty> = _serde::__private::None; } - }); + }) + .chain(implied_let_values); // Collect contents for flatten fields into a buffer let let_collect = if has_flatten { @@ -2546,6 +2602,31 @@ fn deserialize_map( None }; + let implied_values_arm = implied_fields.iter().map(|(key, value, ident)| { + let visit = { + let span = key.span(); + let func = + quote_spanned!(span=> _serde::de::MapAccess::next_value::<::std::string::String>); + + quote! { + if _serde::__private::Option::is_some(&#ident) { + return _serde::__private::Err(<__A::Error as _serde::de::Error>::duplicate_field(#key)); + } + let inner = #func(&mut __map)?; + if inner != #value { + return _serde::__private::Err(_serde::de::Error::invalid_value(_serde::de::Unexpected::Str(&inner), &#value)); + } + #ident = _serde::__private::Some(inner); + } + }; + + quote! { + __Field::#ident => { + #visit + } + } + }); + // Match arms to extract a value for a field. let value_arms = fields_names .iter() @@ -2584,7 +2665,7 @@ fn deserialize_map( #name = _serde::__private::Some(#visit); } } - }); + }).chain(implied_values_arm); // Visit ignored values to consume them let ignored_arm = if has_flatten { @@ -2623,6 +2704,21 @@ fn deserialize_map( } }; + let implied_extract_values = implied_fields.iter().map(|(key, _, ident)| { + let span = key.span(); + let func = quote_spanned!(span=> _serde::__private::de::missing_field); + let missing_expr = quote! { + #func(#key)? + }; + + quote! { + match #ident.take() { + _serde::__private::Some(_) => {}, + _serde::__private::None => #missing_expr + }; + } + }); + let extract_values = fields_names .iter() .filter(|&&(field, _)| !field.attrs.skip_deserializing() && !field.attrs.flatten()) @@ -2635,7 +2731,8 @@ fn deserialize_map( _serde::__private::None => #missing_expr }; } - }); + }) + .chain(implied_extract_values); let extract_collected = fields_names .iter() @@ -2749,6 +2846,21 @@ fn deserialize_map_in_place( .map(|(i, field)| (field, field_i(i))) .collect(); + let implied_fields: Vec<_> = cattrs + .implied() + .iter() + .map(|(key, value)| { + let ident = Ident::new(&format!("__field_{}", key.value), key.span); + (key, value, ident) + }) + .collect(); + + let implied_let_flags = implied_fields.iter().map(|(_, _, name)| { + quote! { + let mut #name: bool = false; + } + }); + // For deserialize_in_place, declare booleans for each field that will be // deserialized. let let_flags = fields_names @@ -2758,7 +2870,32 @@ fn deserialize_map_in_place( quote! { let mut #name: bool = false; } - }); + }) + .chain(implied_let_flags); + + let implied_values_arm = implied_fields.iter().map(|(key, value, ident)| { + let visit = { + let span = key.span(); + let func = + quote_spanned!(span=> _serde::de::MapAccess::next_value::<::std::string::String>); + quote! { + let inner = #func(&mut __map)?; + if inner != #value { + return Err(de::Error::custom("Invalid method value")); + } + } + }; + + quote! { + __Field::#ident => { + if #ident { + return _serde::__private::Err(<__A::Error as _serde::de::Error>::duplicate_field(#key)); + } + #visit + #ident = true; + } + } + }); // Match arms to extract a value for a field. let value_arms_from = fields_names @@ -2796,7 +2933,7 @@ fn deserialize_map_in_place( #name = true; } } - }); + }).chain(implied_values_arm); // Visit ignored values to consume them let ignored_arm = if cattrs.deny_unknown_fields() { diff --git a/serde_derive/src/internals/attr.rs b/serde_derive/src/internals/attr.rs index 6d846ed01..3496fde86 100644 --- a/serde_derive/src/internals/attr.rs +++ b/serde_derive/src/internals/attr.rs @@ -170,6 +170,7 @@ pub struct Container { identifier: Identifier, serde_path: Option, is_packed: bool, + implied: Vec<(Name, Name)>, /// Error message generated when type can't be deserialized expecting: Option, non_exhaustive: bool, @@ -259,6 +260,7 @@ impl Container { let mut serde_path = Attr::none(cx, CRATE); let mut expecting = Attr::none(cx, EXPECTING); let mut non_exhaustive = false; + let mut implied = VecAttr::none(cx, IMPLIED); for attr in &item.attrs { if attr.path() != SERDE { @@ -491,6 +493,11 @@ impl Container { if let Some(s) = get_lit_str(cx, EXPECTING, &meta)? { expecting.set(&meta.path, s.value()); } + } else if meta.path == IMPLIED { + // #[serde(implied(key = "key", value = "value")] + if let Some((key, value)) = parse_key_value_names(cx,IMPLIED,&meta)? { + implied.insert(&meta.path, (key, value)); + } } else { let path = meta.path.to_token_stream().to_string().replace(' ', ""); return Err( @@ -540,6 +547,7 @@ impl Container { identifier: decide_identifier(cx, item, field_identifier, variant_identifier), serde_path: serde_path.get(), is_packed, + implied: implied.get(), expecting: expecting.get(), non_exhaustive, } @@ -601,6 +609,10 @@ impl Container { self.is_packed } + pub fn implied(&self) -> &[(Name, Name)] { + &self.implied + } + pub fn identifier(&self) -> Identifier { self.identifier } @@ -1547,6 +1559,40 @@ fn parse_lit_into_ty( }) } +fn parse_key_value_names( + cx: &Ctxt, + attr_name: Symbol, + meta: &ParseNestedMeta, +) -> syn::Result> { + let mut key = None; + let mut value = None; + + meta.parse_nested_meta(|meta| { + if meta.path == KEY { + if let Some(v) = get_lit_str2(cx, attr_name, KEY, &meta)? { + key = Some(v) + } + } else if meta.path == VALUE { + if let Some(v) = get_lit_str2(cx, attr_name, VALUE, &meta)? { + value = Some(v); + } + } else { + return Err(meta.error(format_args!( + "malformed {0} attribute, expected `{0}(key = ..., value = ...)`", + attr_name, + ))); + } + Ok(()) + })?; + + Ok( + match (key.as_ref().map(Name::from), value.as_ref().map(Name::from)) { + (Some(key), Some(value)) => Some((key, value)), + _ => None, + }, + ) +} + // Parses a string literal like "'a + 'b + 'c" containing a nonempty list of // lifetimes separated by `+`. fn parse_lit_into_lifetimes( diff --git a/serde_derive/src/internals/symbol.rs b/serde_derive/src/internals/symbol.rs index 59ef8de7c..4b1c9ec11 100644 --- a/serde_derive/src/internals/symbol.rs +++ b/serde_derive/src/internals/symbol.rs @@ -18,7 +18,9 @@ pub const FIELD_IDENTIFIER: Symbol = Symbol("field_identifier"); pub const FLATTEN: Symbol = Symbol("flatten"); pub const FROM: Symbol = Symbol("from"); pub const GETTER: Symbol = Symbol("getter"); +pub const KEY: Symbol = Symbol("key"); pub const INTO: Symbol = Symbol("into"); +pub const IMPLIED: Symbol = Symbol("implied"); pub const NON_EXHAUSTIVE: Symbol = Symbol("non_exhaustive"); pub const OTHER: Symbol = Symbol("other"); pub const REMOTE: Symbol = Symbol("remote"); @@ -37,6 +39,7 @@ pub const TAG: Symbol = Symbol("tag"); pub const TRANSPARENT: Symbol = Symbol("transparent"); pub const TRY_FROM: Symbol = Symbol("try_from"); pub const UNTAGGED: Symbol = Symbol("untagged"); +pub const VALUE: Symbol = Symbol("value"); pub const VARIANT_IDENTIFIER: Symbol = Symbol("variant_identifier"); pub const WITH: Symbol = Symbol("with"); diff --git a/serde_derive/src/ser.rs b/serde_derive/src/ser.rs index c92800d51..1bee2f692 100644 --- a/serde_derive/src/ser.rs +++ b/serde_derive/src/ser.rs @@ -327,8 +327,11 @@ fn serialize_struct_as_struct( fields: &[Field], cattrs: &attr::Container, ) -> Fragment { - let serialize_fields = + let mut implied_fields = + serialize_implied_fields(cattrs.implied(), &StructTrait::SerializeStruct); + let mut serialize_fields = serialize_struct_visitor(fields, params, false, &StructTrait::SerializeStruct); + serialize_fields.append(&mut implied_fields); let type_name = cattrs.name().serialize_name(); @@ -342,6 +345,8 @@ fn serialize_struct_as_struct( let let_mut = mut_if(serialized_fields.peek().is_some() || tag_field_exists); + let implied_fields = cattrs.implied().len(); + let len = serialized_fields .map(|field| match field.attrs.skip_serializing_if() { None => quote!(1), @@ -351,7 +356,7 @@ fn serialize_struct_as_struct( } }) .fold( - quote!(#tag_field_exists as usize), + quote!(#tag_field_exists as usize + #implied_fields as usize), |sum, expr| quote!(#sum + #expr), ); @@ -1165,6 +1170,25 @@ fn serialize_struct_visitor( .collect() } +fn serialize_implied_fields( + fields: &[(Name, Name)], + struct_trait: &StructTrait, +) -> Vec { + fields + .iter() + .map(|(key, value)| { + let ser = { + let func = struct_trait.serialize_field(key.span()); + quote! { + #func(&mut __serde_state, #key, #value)?; + } + }; + + ser + }) + .collect() +} + fn wrap_serialize_field_with( params: &Parameters, field_ty: &syn::Type, diff --git a/test_suite/tests/test_annotations.rs b/test_suite/tests/test_annotations.rs index 878c88981..4772eb615 100644 --- a/test_suite/tests/test_annotations.rs +++ b/test_suite/tests/test_annotations.rs @@ -1471,6 +1471,31 @@ fn test_missing_renamed_field_enum() { ); } +#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +#[serde(implied(key = "test", value = "one"))] +struct ImpliedStruct { + val: (), +} + +#[test] +fn test_implied() { + assert_tokens( + &ImpliedStruct { val: () }, + &[ + Token::Struct { + name: "ImpliedStruct", + len: 2, + }, + Token::Str("val"), + Token::Unit, + Token::Str("test"), + Token::Str("one"), + Token::StructEnd, + ], + ); +} + #[derive(Debug, PartialEq, Deserialize)] enum InvalidLengthEnum { A(i32, i32, i32), diff --git a/test_suite/tests/ui/malformed/implied.rs b/test_suite/tests/ui/malformed/implied.rs new file mode 100644 index 000000000..8d6368742 --- /dev/null +++ b/test_suite/tests/ui/malformed/implied.rs @@ -0,0 +1,9 @@ +use serde_derive::Serialize; + +#[derive(Serialize)] +#[serde(implied(unknown))] +struct S { + x: (), +} + +fn main() {} diff --git a/test_suite/tests/ui/malformed/implied.stderr b/test_suite/tests/ui/malformed/implied.stderr new file mode 100644 index 000000000..bf790b634 --- /dev/null +++ b/test_suite/tests/ui/malformed/implied.stderr @@ -0,0 +1,5 @@ +error: malformed implied attribute, expected `implied(key = ..., value = ...)` + --> tests/ui/malformed/implied.rs:4:17 + | +4 | #[serde(implied(unknown))] + | ^^^^^^^