Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support custom recursion limits at build time #785

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
10 changes: 10 additions & 0 deletions prost-build/src/code_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ impl<'a> CodeGenerator<'a> {
"#[derive(Clone, PartialEq, {}::Message)]\n",
self.config.prost_path.as_deref().unwrap_or("::prost")
));
self.append_recursion_limit(&fq_message_name);
self.push_indent();
self.buf.push_str("pub struct ");
self.buf.push_str(&to_upper_camel(&message_name));
Expand Down Expand Up @@ -272,6 +273,15 @@ impl<'a> CodeGenerator<'a> {
}
}

fn append_recursion_limit(&mut self, fq_message_name: &str) {
assert_eq!(b'.', fq_message_name.as_bytes()[0]);
if let Some(limit) = self.config.recursion_limits.get_first(fq_message_name) {
push_indent(self.buf, self.depth);
self.buf.push_str(&format!("#[RecursionLimit({})]", limit));
self.buf.push('\n');
}
}

fn append_field_attributes(&mut self, fq_message_name: &str, field_name: &str) {
assert_eq!(b'.', fq_message_name.as_bytes()[0]);
for attribute in self
Expand Down
21 changes: 21 additions & 0 deletions prost-build/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ pub struct Config {
bytes_type: PathMap<BytesType>,
type_attributes: PathMap<String>,
field_attributes: PathMap<String>,
recursion_limits: PathMap<u32>,
prost_types: bool,
strip_enum_prefix: bool,
out_dir: Option<PathBuf>,
Expand Down Expand Up @@ -468,6 +469,25 @@ impl Config {
self
}

/// Configure a custom recursion limit for certain messages.
///
/// This defaults to 100, and can be disabled with the no-recursion-limit crate feature.
///
/// # Example
///
/// ```rust
/// # let mut config = prost_build::Config::new();
/// config.recursion_limit("my_messages.MyMessageType", 1000);
/// ```
pub fn recursion_limit<P>(&mut self, path: P, limit: u32) -> &mut Self
where
P: AsRef<str>,
{
self.recursion_limits
.insert(path.as_ref().to_string(), limit);
self
}

/// Configures the code generator to use the provided service generator.
pub fn service_generator(&mut self, service_generator: Box<dyn ServiceGenerator>) -> &mut Self {
self.service_generator = Some(service_generator);
Expand Down Expand Up @@ -1101,6 +1121,7 @@ impl default::Default for Config {
bytes_type: PathMap::default(),
type_attributes: PathMap::default(),
field_attributes: PathMap::default(),
recursion_limits: PathMap::default(),
prost_types: true,
strip_enum_prefix: true,
out_dir: None,
Expand Down
20 changes: 19 additions & 1 deletion prost-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,20 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {

let ident = input.ident;

let recursion_limit: u32 = if let Some(attr) = input
.attrs
.iter()
.find(|attr| attr.path.is_ident("RecursionLimit"))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we should make this one camel case like serde does https://serde.rs/container-attrs.html

so #[prost(recursion_limit = 5)] etc

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm having trouble getting a bespoke attribute parser working. Would it be okay to pull in darling as a dependency? https://crates.io/crates/darling

Another option is to skip proper attribute parsing and only handle the single attribute we have for now.

{
if let syn::Lit::Int(attr) = attr.parse_args().unwrap() {
attr.base10_parse().unwrap()
} else {
panic!("unexpected RecursionLimit type: {:?}", attr)
}
} else {
100
};

let variant_data = match input.data {
Data::Struct(variant_data) => variant_data,
Data::Enum(..) => bail!("Message can not be derived for an enum"),
Expand Down Expand Up @@ -187,6 +201,10 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {

let expanded = quote! {
impl #impl_generics ::prost::Message for #ident #ty_generics #where_clause {
fn recursion_limit() -> u32 {
#recursion_limit
}

#[allow(unused_variables)]
fn encode_raw<B>(&self, buf: &mut B) where B: ::prost::bytes::BufMut {
#(#encode)*
Expand Down Expand Up @@ -238,7 +256,7 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
Ok(expanded.into())
}

#[proc_macro_derive(Message, attributes(prost))]
#[proc_macro_derive(Message, attributes(prost, RecursionLimit))]
pub fn message(input: TokenStream) -> TokenStream {
try_message(input).unwrap()
}
Expand Down
26 changes: 11 additions & 15 deletions src/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,29 +195,25 @@ pub struct DecodeContext {
/// How many times we can recurse in the current decode stack before we hit
/// the recursion limit.
///
/// The recursion limit is defined by `RECURSION_LIMIT` and cannot be
/// customized. The recursion limit can be ignored by building the Prost
/// crate with the `no-recursion-limit` feature.
/// It defaults to 100 and can be changed using `prost_build::recursion_limit`,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So because this is hidden this doc won't actually be readable. So we need to make sure this is documented at the lib level of prost.

/// or it can be disabled entirely using the `no-recursion-limit` feature.
#[cfg(not(feature = "no-recursion-limit"))]
recurse_count: u32,
pub recurse_count: u32,
seanlinsley marked this conversation as resolved.
Show resolved Hide resolved
}

#[cfg(not(feature = "no-recursion-limit"))]
impl Default for DecodeContext {
#[inline]
fn default() -> DecodeContext {
impl DecodeContext {
pub(crate) fn new(recursion_limit: u32) -> DecodeContext {
DecodeContext {
recurse_count: crate::RECURSION_LIMIT,
#[cfg(not(feature = "no-recursion-limit"))]
recurse_count: recursion_limit,
}
}
}

impl DecodeContext {
/// Call this function before recursively decoding.
///
/// There is no `exit` function since this function creates a new `DecodeContext`
/// to be used at the next level of recursion. Continue to use the old context
// at the previous level of recursion.
/// at the previous level of recursion.
#[cfg(not(feature = "no-recursion-limit"))]
#[inline]
pub(crate) fn enter_recursion(&self) -> DecodeContext {
Expand Down Expand Up @@ -1503,7 +1499,7 @@ mod test {
wire_type,
&mut roundtrip_value,
&mut buf,
DecodeContext::default(),
DecodeContext::new(100),
)
.map_err(|error| TestCaseError::fail(error.to_string()))?;

Expand Down Expand Up @@ -1575,7 +1571,7 @@ mod test {
wire_type,
&mut roundtrip_value,
&mut buf,
DecodeContext::default(),
DecodeContext::new(100),
)
.map_err(|error| TestCaseError::fail(error.to_string()))?;
}
Expand All @@ -1594,7 +1590,7 @@ mod test {
WireType::LengthDelimited,
&mut s,
&mut &buf[..],
DecodeContext::default(),
DecodeContext::new(100),
);
r.expect_err("must be an error");
assert!(s.is_empty());
Expand Down
5 changes: 0 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,6 @@ use bytes::{Buf, BufMut};

use crate::encoding::{decode_varint, encode_varint, encoded_len_varint};

// See `encoding::DecodeContext` for more info.
// 100 is the default recursion limit in the C++ implementation.
#[cfg(not(feature = "no-recursion-limit"))]
const RECURSION_LIMIT: u32 = 100;
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since it sounds like this will never change, I opted to inline it everywhere it's used. That allows documentation to say explicitly that the default recursion limit is 100 instead of requiring that users look up this constant.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually kinda like having this as a constant. I wonder if we could just make a constants module that contains just this one. And we can then have all the deep dive docs on the recursion implementation there and then we just need to link to there from the lib doc page. I feel like that would make it easier to maintain down the line and centralize it a bit.


/// Encodes a length delimiter to the buffer.
///
/// See [Message.encode_length_delimited] for more info.
Expand Down
14 changes: 12 additions & 2 deletions src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ use crate::EncodeError;

/// A Protocol Buffers message.
pub trait Message: Debug + Send + Sync {
/// The recursion limit for decoding protobuf messages.
///
/// Defaults to 100. Can be customized in your build.rs or by using the no-recursion-limit crate feature.
fn recursion_limit() -> u32
LucioFranco marked this conversation as resolved.
Show resolved Hide resolved
where
Self: Sized;

/// Encodes the message to a buffer.
///
/// This method will panic if the buffer has insufficient capacity.
Expand Down Expand Up @@ -135,7 +142,7 @@ pub trait Message: Debug + Send + Sync {
B: Buf,
Self: Sized,
{
let ctx = DecodeContext::default();
let ctx = DecodeContext::new(Self::recursion_limit());
while buf.has_remaining() {
let (tag, wire_type) = decode_key(&mut buf)?;
self.merge_field(tag, wire_type, &mut buf, ctx.clone())?;
Expand All @@ -154,7 +161,7 @@ pub trait Message: Debug + Send + Sync {
WireType::LengthDelimited,
self,
&mut buf,
DecodeContext::default(),
DecodeContext::new(Self::recursion_limit()),
)
}

Expand All @@ -166,6 +173,9 @@ impl<M> Message for Box<M>
where
M: Message,
{
fn recursion_limit() -> u32 {
M::recursion_limit()
}
fn encode_raw<B>(&self, buf: &mut B)
where
B: BufMut,
Expand Down
34 changes: 34 additions & 0 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ use crate::{

/// `google.protobuf.BoolValue`
impl Message for bool {
fn recursion_limit() -> u32 {
0
}
fn encode_raw<B>(&self, buf: &mut B)
where
B: BufMut,
Expand Down Expand Up @@ -58,6 +61,9 @@ impl Message for bool {

/// `google.protobuf.UInt32Value`
impl Message for u32 {
fn recursion_limit() -> u32 {
0
}
fn encode_raw<B>(&self, buf: &mut B)
where
B: BufMut,
Expand Down Expand Up @@ -96,6 +102,9 @@ impl Message for u32 {

/// `google.protobuf.UInt64Value`
impl Message for u64 {
fn recursion_limit() -> u32 {
0
}
fn encode_raw<B>(&self, buf: &mut B)
where
B: BufMut,
Expand Down Expand Up @@ -134,6 +143,10 @@ impl Message for u64 {

/// `google.protobuf.Int32Value`
impl Message for i32 {
fn recursion_limit() -> u32 {
0
}

fn encode_raw<B>(&self, buf: &mut B)
where
B: BufMut,
Expand Down Expand Up @@ -172,6 +185,9 @@ impl Message for i32 {

/// `google.protobuf.Int64Value`
impl Message for i64 {
fn recursion_limit() -> u32 {
0
}
fn encode_raw<B>(&self, buf: &mut B)
where
B: BufMut,
Expand Down Expand Up @@ -210,6 +226,9 @@ impl Message for i64 {

/// `google.protobuf.FloatValue`
impl Message for f32 {
fn recursion_limit() -> u32 {
0
}
fn encode_raw<B>(&self, buf: &mut B)
where
B: BufMut,
Expand Down Expand Up @@ -248,6 +267,9 @@ impl Message for f32 {

/// `google.protobuf.DoubleValue`
impl Message for f64 {
fn recursion_limit() -> u32 {
0
}
fn encode_raw<B>(&self, buf: &mut B)
where
B: BufMut,
Expand Down Expand Up @@ -286,6 +308,9 @@ impl Message for f64 {

/// `google.protobuf.StringValue`
impl Message for String {
fn recursion_limit() -> u32 {
0
}
fn encode_raw<B>(&self, buf: &mut B)
where
B: BufMut,
Expand Down Expand Up @@ -324,6 +349,9 @@ impl Message for String {

/// `google.protobuf.BytesValue`
impl Message for Vec<u8> {
fn recursion_limit() -> u32 {
0
}
fn encode_raw<B>(&self, buf: &mut B)
where
B: BufMut,
Expand Down Expand Up @@ -362,6 +390,9 @@ impl Message for Vec<u8> {

/// `google.protobuf.BytesValue`
impl Message for Bytes {
fn recursion_limit() -> u32 {
0
}
fn encode_raw<B>(&self, buf: &mut B)
where
B: BufMut,
Expand Down Expand Up @@ -400,6 +431,9 @@ impl Message for Bytes {

/// `google.protobuf.Empty`
impl Message for () {
fn recursion_limit() -> u32 {
0
}
fn encode_raw<B>(&self, _buf: &mut B)
where
B: BufMut,
Expand Down
1 change: 1 addition & 0 deletions tests/src/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ fn main() {
.compile_protos(&[src.join("ident_conversion.proto")], includes)
.unwrap();

config.recursion_limit("nesting.E", 200);
config
.compile_protos(&[src.join("nesting.proto")], includes)
.unwrap();
Expand Down
21 changes: 21 additions & 0 deletions tests/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,27 @@ mod tests {
assert!(build_and_roundtrip(101).is_err());
}

#[test]
fn test_deep_nesting_with_custom_recursion_limit() {
fn build_and_roundtrip(depth: usize) -> Result<(), prost::DecodeError> {
use crate::nesting::E;

let mut e = Box::new(E::default());
for _ in 0..depth {
let mut next = Box::new(E::default());
next.e = Some(e);
e = next;
}

let mut buf = Vec::new();
e.encode(&mut buf).unwrap();
E::decode(&*buf).map(|_| ())
}

assert!(build_and_roundtrip(200).is_ok());
assert!(build_and_roundtrip(201).is_err());
}

#[test]
fn test_deep_nesting_oneof() {
fn build_and_roundtrip(depth: usize) -> Result<(), prost::DecodeError> {
Expand Down
10 changes: 10 additions & 0 deletions tests/src/nesting.proto
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,13 @@ message C {
message D {
map<string, D> m = 1;
}

message E {
E e = 1;
repeated E repeated_e = 2;
map<int32, E> map_e = 3;

B b = 4;
repeated B repeated_b = 5;
map<int32, B> map_b = 6;
}