Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement #[serde(implied(key = "key", value = "value")] for adding static values #2908

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 145 additions & 8 deletions serde_derive/src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1042,7 +1042,11 @@ fn deserialize_struct(
let fields_stmt = if has_flatten {
None
} else {
let field_names = deserialized_fields.iter().flat_map(|field| field.aliases);
let implied_fields = cattrs.implied().iter().map(|(key, _)| key);
let field_names = deserialized_fields
.iter()
.flat_map(|field| field.aliases)
.chain(implied_fields);

Some(quote! {
#[doc(hidden)]
Expand Down Expand Up @@ -1155,7 +1159,11 @@ fn deserialize_struct_in_place(
};
let visit_seq = Stmts(deserialize_seq_in_place(params, fields, cattrs, expecting));
let visit_map = Stmts(deserialize_map_in_place(params, fields, cattrs));
let field_names = deserialized_fields.iter().flat_map(|field| field.aliases);
let implied_fields = cattrs.implied().iter().map(|(key, _)| key);
let field_names = deserialized_fields
.iter()
.flat_map(|field| field.aliases)
.chain(implied_fields);
let type_name = cattrs.name().deserialize_name();

let in_place_impl_generics = de_impl_generics.in_place();
Expand Down Expand Up @@ -1276,6 +1284,7 @@ fn prepare_enum_variant_enum(variants: &[Variant]) -> (TokenStream, Stmts) {
true,
None,
fallthrough,
None,
));

(variants_stmt, variant_visitor)
Expand Down Expand Up @@ -2022,8 +2031,29 @@ fn deserialize_generated_identifier(
is_variant: bool,
ignore_variant: Option<TokenStream>,
fallthrough: Option<TokenStream>,
implied_fields: Option<&[(Name, Name)]>,
) -> Fragment {
let this_value = quote!(__Field);

let implied_fields: Vec<_> = implied_fields
.unwrap_or_default()
.iter()
.map(|(key, value)| {
let ident = Ident::new(&format!("__field_{}", key.value), key.span);
(key, value, ident)
})
.collect();

let implied_idents: Vec<_> = implied_fields.iter().map(|(_, _, ident)| ident).collect();
let implied_field_mapping: Vec<_> = implied_fields
.iter()
.map(|(key, _, ident)| {
quote! {
#key => _serde::__private::Ok(#this_value::#ident),
}
})
.collect();

let field_idents: &Vec<_> = &deserialized_fields
.iter()
.map(|field| &field.ident)
Expand All @@ -2037,6 +2067,7 @@ fn deserialize_generated_identifier(
None,
!is_variant && has_flatten,
None,
&implied_field_mapping,
));

let lifetime = if !is_variant && has_flatten {
Expand All @@ -2049,6 +2080,7 @@ fn deserialize_generated_identifier(
#[allow(non_camel_case_types)]
#[doc(hidden)]
enum __Field #lifetime {
#(#implied_idents,)*
#(#field_idents,)*
#ignore_variant
}
Expand Down Expand Up @@ -2101,6 +2133,7 @@ fn deserialize_field_identifier(
false,
ignore_variant,
fallthrough,
Some(cattrs.implied()),
))
}

Expand Down Expand Up @@ -2162,7 +2195,11 @@ fn deserialize_custom_identifier(
})
.collect();

let names = idents_aliases.iter().flat_map(|variant| variant.aliases);
let implied_fields = cattrs.implied().iter().map(|(key, _)| key);
let names = idents_aliases
.iter()
.flat_map(|variant| variant.aliases)
.chain(implied_fields);

let names_const = if fallthrough.is_some() {
None
Expand Down Expand Up @@ -2191,6 +2228,7 @@ fn deserialize_custom_identifier(
fallthrough_borrowed,
false,
cattrs.expecting(),
&[],
));

quote_block! {
Expand Down Expand Up @@ -2225,6 +2263,7 @@ fn deserialize_identifier(
fallthrough_borrowed: Option<TokenStream>,
collect_other_fields: bool,
expecting: Option<&str>,
additional_mapping: &[TokenStream],
) -> Fragment {
let str_mapping = deserialized_fields.iter().map(|field| {
let ident = &field.ident;
Expand Down Expand Up @@ -2484,6 +2523,7 @@ fn deserialize_identifier(
{
match __value {
#(#str_mapping)*
#(#additional_mapping)*
_ => {
#value_as_str_content
#fallthrough_arm
Expand Down Expand Up @@ -2523,6 +2563,21 @@ fn deserialize_map(
.map(|(i, field)| (field, field_i(i)))
.collect();

let implied_fields: Vec<_> = cattrs
.implied()
.iter()
.map(|(key, value)| {
let ident = Ident::new(&format!("__field_{}", key.value), key.span);
(key, value, ident)
})
.collect();

let implied_let_values = implied_fields.iter().map(|(_, _, ident)| {
quote! {
let mut #ident: _serde::__private::Option<::std::string::String> = _serde::__private::None;
}
});

// Declare each field that will be deserialized.
let let_values = fields_names
.iter()
Expand All @@ -2532,7 +2587,8 @@ fn deserialize_map(
quote! {
let mut #name: _serde::__private::Option<#field_ty> = _serde::__private::None;
}
});
})
.chain(implied_let_values);

// Collect contents for flatten fields into a buffer
let let_collect = if has_flatten {
Expand All @@ -2546,6 +2602,31 @@ fn deserialize_map(
None
};

let implied_values_arm = implied_fields.iter().map(|(key, value, ident)| {
let visit = {
let span = key.span();
let func =
quote_spanned!(span=> _serde::de::MapAccess::next_value::<::std::string::String>);

quote! {
if _serde::__private::Option::is_some(&#ident) {
return _serde::__private::Err(<__A::Error as _serde::de::Error>::duplicate_field(#key));
}
let inner = #func(&mut __map)?;
if inner != #value {
return _serde::__private::Err(_serde::de::Error::invalid_value(_serde::de::Unexpected::Str(&inner), &#value));
}
#ident = _serde::__private::Some(inner);
}
};

quote! {
__Field::#ident => {
#visit
}
}
});

// Match arms to extract a value for a field.
let value_arms = fields_names
.iter()
Expand Down Expand Up @@ -2584,7 +2665,7 @@ fn deserialize_map(
#name = _serde::__private::Some(#visit);
}
}
});
}).chain(implied_values_arm);

// Visit ignored values to consume them
let ignored_arm = if has_flatten {
Expand Down Expand Up @@ -2623,6 +2704,21 @@ fn deserialize_map(
}
};

let implied_extract_values = implied_fields.iter().map(|(key, _, ident)| {
let span = key.span();
let func = quote_spanned!(span=> _serde::__private::de::missing_field);
let missing_expr = quote! {
#func(#key)?
};

quote! {
match #ident.take() {
_serde::__private::Some(_) => {},
_serde::__private::None => #missing_expr
};
}
});

let extract_values = fields_names
.iter()
.filter(|&&(field, _)| !field.attrs.skip_deserializing() && !field.attrs.flatten())
Expand All @@ -2635,7 +2731,8 @@ fn deserialize_map(
_serde::__private::None => #missing_expr
};
}
});
})
.chain(implied_extract_values);

let extract_collected = fields_names
.iter()
Expand Down Expand Up @@ -2749,6 +2846,21 @@ fn deserialize_map_in_place(
.map(|(i, field)| (field, field_i(i)))
.collect();

let implied_fields: Vec<_> = cattrs
.implied()
.iter()
.map(|(key, value)| {
let ident = Ident::new(&format!("__field_{}", key.value), key.span);
(key, value, ident)
})
.collect();

let implied_let_flags = implied_fields.iter().map(|(_, _, name)| {
quote! {
let mut #name: bool = false;
}
});

// For deserialize_in_place, declare booleans for each field that will be
// deserialized.
let let_flags = fields_names
Expand All @@ -2758,7 +2870,32 @@ fn deserialize_map_in_place(
quote! {
let mut #name: bool = false;
}
});
})
.chain(implied_let_flags);

let implied_values_arm = implied_fields.iter().map(|(key, value, ident)| {
let visit = {
let span = key.span();
let func =
quote_spanned!(span=> _serde::de::MapAccess::next_value::<::std::string::String>);
quote! {
let inner = #func(&mut __map)?;
if inner != #value {
return Err(de::Error::custom("Invalid method value"));
}
}
};

quote! {
__Field::#ident => {
if #ident {
return _serde::__private::Err(<__A::Error as _serde::de::Error>::duplicate_field(#key));
}
#visit
#ident = true;
}
}
});

// Match arms to extract a value for a field.
let value_arms_from = fields_names
Expand Down Expand Up @@ -2796,7 +2933,7 @@ fn deserialize_map_in_place(
#name = true;
}
}
});
}).chain(implied_values_arm);

// Visit ignored values to consume them
let ignored_arm = if cattrs.deny_unknown_fields() {
Expand Down
46 changes: 46 additions & 0 deletions serde_derive/src/internals/attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ pub struct Container {
identifier: Identifier,
serde_path: Option<syn::Path>,
is_packed: bool,
implied: Vec<(Name, Name)>,
/// Error message generated when type can't be deserialized
expecting: Option<String>,
non_exhaustive: bool,
Expand Down Expand Up @@ -259,6 +260,7 @@ impl Container {
let mut serde_path = Attr::none(cx, CRATE);
let mut expecting = Attr::none(cx, EXPECTING);
let mut non_exhaustive = false;
let mut implied = VecAttr::none(cx, IMPLIED);

for attr in &item.attrs {
if attr.path() != SERDE {
Expand Down Expand Up @@ -491,6 +493,11 @@ impl Container {
if let Some(s) = get_lit_str(cx, EXPECTING, &meta)? {
expecting.set(&meta.path, s.value());
}
} else if meta.path == IMPLIED {
// #[serde(implied(key = "key", value = "value")]
if let Some((key, value)) = parse_key_value_names(cx,IMPLIED,&meta)? {
implied.insert(&meta.path, (key, value));
}
} else {
let path = meta.path.to_token_stream().to_string().replace(' ', "");
return Err(
Expand Down Expand Up @@ -540,6 +547,7 @@ impl Container {
identifier: decide_identifier(cx, item, field_identifier, variant_identifier),
serde_path: serde_path.get(),
is_packed,
implied: implied.get(),
expecting: expecting.get(),
non_exhaustive,
}
Expand Down Expand Up @@ -601,6 +609,10 @@ impl Container {
self.is_packed
}

pub fn implied(&self) -> &[(Name, Name)] {
&self.implied
}

pub fn identifier(&self) -> Identifier {
self.identifier
}
Expand Down Expand Up @@ -1547,6 +1559,40 @@ fn parse_lit_into_ty(
})
}

fn parse_key_value_names(
cx: &Ctxt,
attr_name: Symbol,
meta: &ParseNestedMeta,
) -> syn::Result<Option<(Name, Name)>> {
let mut key = None;
let mut value = None;

meta.parse_nested_meta(|meta| {
if meta.path == KEY {
if let Some(v) = get_lit_str2(cx, attr_name, KEY, &meta)? {
key = Some(v)
}
} else if meta.path == VALUE {
if let Some(v) = get_lit_str2(cx, attr_name, VALUE, &meta)? {
value = Some(v);
}
} else {
return Err(meta.error(format_args!(
"malformed {0} attribute, expected `{0}(key = ..., value = ...)`",
attr_name,
)));
}
Ok(())
})?;

Ok(
match (key.as_ref().map(Name::from), value.as_ref().map(Name::from)) {
(Some(key), Some(value)) => Some((key, value)),
_ => None,
},
)
}

// Parses a string literal like "'a + 'b + 'c" containing a nonempty list of
// lifetimes separated by `+`.
fn parse_lit_into_lifetimes(
Expand Down
Loading