Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Typed catchers #2814

Draft
wants to merge 27 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
56e7fa6
Initial brush
the10thWiz Jun 20, 2024
c0ad038
Working example
the10thWiz Jun 26, 2024
99e2109
Add error type to logs
the10thWiz Jun 29, 2024
09c56c7
Major improvements
the10thWiz Jun 29, 2024
b68900f
Fix whitespace
the10thWiz Jun 30, 2024
4cb3a3a
Revert local changes to scripts dir
the10thWiz Jun 30, 2024
ac3a7fa
Ensure examples pass CI
the10thWiz Jun 30, 2024
eaea6f6
Add Transient impl for serde::json::Error
the10thWiz Jun 30, 2024
1308c19
tmp
the10thWiz Jul 1, 2024
f8c8bb8
Rework catch attribute
the10thWiz Jul 2, 2024
dea224f
Update tests to use new #[catch] macro
the10thWiz Jul 2, 2024
7b8689c
Update transient and use new features in examples
the10thWiz Jul 13, 2024
fb796fc
Update guide
the10thWiz Jul 13, 2024
af68f5e
Major changes
the10thWiz Aug 17, 2024
a59cb04
Update core server code to use new error trait
the10thWiz Aug 17, 2024
6427db2
Updates to improve many aspects
the10thWiz Aug 24, 2024
6d06ac7
Update code to work properly with borrowed errors
the10thWiz Sep 3, 2024
04ae827
Fix formatting issues
the10thWiz Sep 3, 2024
99bba53
Update codegen with many of the new changes
the10thWiz Sep 3, 2024
f0f2342
Major fixes for matching and responder
the10thWiz Sep 3, 2024
84ba0b7
Update to pass tests
the10thWiz Sep 3, 2024
6ab2d13
Add FromError
the10thWiz Sep 4, 2024
b55b9c7
Implement TypedError for form errors
the10thWiz Sep 4, 2024
61a4b44
Add Fairing support, and update examples to match new APIs
the10thWiz Sep 4, 2024
61cd326
Fix safety issues & comments
the10thWiz Sep 7, 2024
3fbf1b4
Add derive macro for `TypedError`
the10thWiz Sep 8, 2024
c263a6c
Update Fairings types to fix issues
the10thWiz Sep 8, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions contrib/dyn_templates/src/fairing.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use rocket::{Rocket, Build, Orbit};
use rocket::fairing::{self, Fairing, Info, Kind};
use rocket::figment::{Source, value::magic::RelativePathBuf};
use rocket::catcher::TypedError;
use rocket::trace::Trace;

use crate::context::{Callback, Context, ContextManager};
Expand Down
26 changes: 14 additions & 12 deletions contrib/dyn_templates/src/template.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,19 +265,21 @@ impl Template {
/// extension and a fixed-size body containing the rendered template. If
/// rendering fails, an `Err` of `Status::InternalServerError` is returned.
impl<'r> Responder<'r, 'static> for Template {
fn respond_to(self, req: &'r Request<'_>) -> response::Result<'static> {
let ctxt = req.rocket()
.state::<ContextManager>()
.ok_or_else(|| {
error!(
"uninitialized template context: missing `Template::fairing()`.\n\
To use templates, you must attach `Template::fairing()`."
);

Status::InternalServerError
})?;
type Error = std::convert::Infallible;
fn respond_to(self, req: &'r Request<'_>) -> response::Outcome<'static, Self::Error> {
if let Some(ctxt) = req.rocket().state::<ContextManager>() {
match self.finalize(&ctxt.context()) {
Ok(v) => v.respond_to(req),
Err(s) => response::Outcome::Forward(s),
}
} else {
error!(
"uninitialized template context: missing `Template::fairing()`.\n\
To use templates, you must attach `Template::fairing()`."
);

self.finalize(&ctxt.context())?.respond_to(req)
response::Outcome::Forward(Status::InternalServerError)
}
}
}

Expand Down
6 changes: 4 additions & 2 deletions contrib/ws/src/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,8 @@ impl<'r> FromRequest<'r> for WebSocket {
}

impl<'r, 'o: 'r> Responder<'r, 'o> for Channel<'o> {
fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> {
type Error = std::convert::Infallible;
fn respond_to(self, _: &'r Request<'_>) -> response::Outcome<'o, Self::Error> {
Response::build()
.raw_header("Sec-Websocket-Version", "13")
.raw_header("Sec-WebSocket-Accept", self.ws.key.clone())
Expand All @@ -250,7 +251,8 @@ impl<'r, 'o: 'r> Responder<'r, 'o> for Channel<'o> {
impl<'r, 'o: 'r, S> Responder<'r, 'o> for MessageStream<'o, S>
where S: futures::Stream<Item = Result<Message>> + Send + 'o
{
fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> {
type Error = std::convert::Infallible;
fn respond_to(self, _: &'r Request<'_>) -> response::Outcome<'o, Self::Error> {
Response::build()
.raw_header("Sec-Websocket-Version", "13")
.raw_header("Sec-WebSocket-Accept", self.ws.key.clone())
Expand Down
98 changes: 71 additions & 27 deletions core/codegen/src/attribute/catch/mod.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,64 @@
mod parse;

use devise::ext::SpanDiagnosticExt;
use devise::{Spanned, Result};
use devise::{Result, Spanned};
use proc_macro2::{TokenStream, Span};

use crate::http_codegen::Optional;
use crate::syn_ext::ReturnTypeExt;
use crate::syn_ext::{IdentExt, ReturnTypeExt};
use crate::exports::*;

use self::parse::ErrorGuard;

use super::param::Guard;

fn error_type(guard: &ErrorGuard) -> TokenStream {
let ty = &guard.ty;
quote! {
(#_catcher::TypeId::of::<#ty>(), ::std::any::type_name::<#ty>())
}
}

fn error_guard_decl(guard: &ErrorGuard) -> TokenStream {
let (ident, ty) = (guard.ident.rocketized(), &guard.ty);
quote_spanned! { ty.span() =>
let #ident: &#ty = match #_catcher::downcast(__error_init) {
Some(v) => v,
None => return #_Result::Err(#__status),
};
}
}

fn request_guard_decl(guard: &Guard) -> TokenStream {
let (ident, ty) = (guard.fn_ident.rocketized(), &guard.ty);
quote_spanned! { ty.span() =>
let #ident: #ty = match <#ty as #FromError>::from_error(
#__status,
#__req,
__error_init
).await {
#_Result::Ok(__v) => __v,
#_Result::Err(__e) => {
::rocket::trace::info!(
name: "forward",
target: concat!("rocket::codegen::catch::", module_path!()),
parameter = stringify!(#ident),
type_name = stringify!(#ty),
status = __e.code,
"error guard forwarding; trying next catcher"
);

return #_Err(#__status);
},
};
}
}

pub fn _catch(
args: proc_macro::TokenStream,
input: proc_macro::TokenStream
) -> Result<TokenStream> {
// Parse and validate all of the user's input.
let catch = parse::Attribute::parse(args.into(), input)?;
let catch = parse::Attribute::parse(args.into(), input.into())?;

// Gather everything we'll need to generate the catcher.
let user_catcher_fn = &catch.function;
Expand All @@ -22,35 +67,28 @@ pub fn _catch(
let status_code = Optional(catch.status.map(|s| s.code));
let deprecated = catch.function.attrs.iter().find(|a| a.path().is_ident("deprecated"));

// Determine the number of parameters that will be passed in.
if catch.function.sig.inputs.len() > 2 {
return Err(catch.function.sig.paren_token.span.join()
.error("invalid number of arguments: must be zero, one, or two")
.help("catchers optionally take `&Request` or `Status, &Request`"));
}

// This ensures that "Responder not implemented" points to the return type.
let return_type_span = catch.function.sig.output.ty()
.map(|ty| ty.span())
.unwrap_or_else(Span::call_site);

// Set the `req` and `status` spans to that of their respective function
// arguments for a more correct `wrong type` error span. `rev` to be cute.
let codegen_args = &[__req, __status];
let inputs = catch.function.sig.inputs.iter().rev()
.zip(codegen_args.iter())
.map(|(fn_arg, codegen_arg)| match fn_arg {
syn::FnArg::Receiver(_) => codegen_arg.respanned(fn_arg.span()),
syn::FnArg::Typed(a) => codegen_arg.respanned(a.ty.span())
}).rev();
let error_guard = catch.error_guard.as_ref().map(error_guard_decl);
let error_type = Optional(catch.error_guard.as_ref().map(error_type));
let request_guards = catch.request_guards.iter().map(request_guard_decl);
let parameter_names = catch.arguments.map.values()
.map(|(ident, _)| ident.rocketized());

// We append `.await` to the function call if this is `async`.
let dot_await = catch.function.sig.asyncness
.map(|a| quote_spanned!(a.span() => .await));

let catcher_response = quote_spanned!(return_type_span => {
let ___responder = #user_catcher_fn_name(#(#inputs),*) #dot_await;
#_response::Responder::respond_to(___responder, #__req)?
let ___responder = #user_catcher_fn_name(#(#parameter_names),*) #dot_await;
match #_response::Responder::respond_to(___responder, #__req) {
#Outcome::Success(v) => v,
// If the responder fails, we drop any typed error, and convert to 500
#Outcome::Error(_) | #Outcome::Forward(_) => return Err(#Status::InternalServerError),
}
});

// Generate the catcher, keeping the user's input around.
Expand All @@ -68,20 +106,26 @@ pub fn _catch(
fn into_info(self) -> #_catcher::StaticInfo {
fn monomorphized_function<'__r>(
#__status: #Status,
#__req: &'__r #Request<'_>
#__req: &'__r #Request<'_>,
__error_init: #_Option<&'__r (dyn #TypedError<'__r> + '__r)>,
) -> #_catcher::BoxFuture<'__r> {
#_Box::pin(async move {
#error_guard
#(#request_guards)*
let __response = #catcher_response;
#Response::build()
.status(#__status)
.merge(__response)
.ok()
#_Result::Ok(
#Response::build()
.status(#__status)
.merge(__response)
.finalize()
)
})
}

#_catcher::StaticInfo {
name: ::core::stringify!(#user_catcher_fn_name),
code: #status_code,
error_type: #error_type,
handler: monomorphized_function,
location: (::core::file!(), ::core::line!(), ::core::column!()),
}
Expand Down
113 changes: 107 additions & 6 deletions core/codegen/src/attribute/catch/parse.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
use devise::ext::SpanDiagnosticExt;
use devise::{MetaItem, Spanned, Result, FromMeta, Diagnostic};
use proc_macro2::TokenStream;
use devise::ext::{SpanDiagnosticExt, TypeExt};
use devise::{Diagnostic, FromMeta, MetaItem, Result, SpanWrapped, Spanned};
use proc_macro2::{Span, TokenStream, Ident};
use quote::ToTokens;

use crate::attribute::param::{Dynamic, Guard};
use crate::name::{ArgumentMap, Arguments, Name};
use crate::proc_macro_ext::Diagnostics;
use crate::syn_ext::FnArgExt;
use crate::{http, http_codegen};

/// This structure represents the parsed `catch` attribute and associated items.
Expand All @@ -10,13 +15,65 @@ pub struct Attribute {
pub status: Option<http::Status>,
/// The function that was decorated with the `catch` attribute.
pub function: syn::ItemFn,
pub arguments: Arguments,
pub error_guard: Option<ErrorGuard>,
pub request_guards: Vec<Guard>,
}

pub struct ErrorGuard {
pub span: Span,
pub name: Name,
pub ident: syn::Ident,
pub ty: syn::Type,
}

impl ErrorGuard {
fn new(param: SpanWrapped<Dynamic>, args: &Arguments) -> Result<Self> {
if let Some((ident, ty)) = args.map.get(&param.name) {
match ty {
syn::Type::Reference(syn::TypeReference { elem, .. }) => Ok(Self {
span: param.span(),
name: param.name.clone(),
ident: ident.clone(),
ty: elem.as_ref().clone(),
}),
ty => {
let msg = format!(
"Error argument must be a reference, found `{}`",
ty.to_token_stream()
);
let diag = param.span()
.error("invalid type")
.span_note(ty.span(), msg)
.help(format!("Perhaps use `&{}` instead", ty.to_token_stream()));
Err(diag)
}
}
} else {
let msg = format!("expected argument named `{}` here", param.name);
let diag = param.span().error("unused parameter").span_note(args.span, msg);
Err(diag)
}
}
}

fn status_guard(param: SpanWrapped<Dynamic>, args: &Arguments) -> Result<(Name, Ident)> {
if let Some((ident, _)) = args.map.get(&param.name) {
Ok((param.name.clone(), ident.clone()))
} else {
let msg = format!("expected argument named `{}` here", param.name);
let diag = param.span().error("unused parameter").span_note(args.span, msg);
Err(diag)
}
}

/// We generate a full parser for the meta-item for great error messages.
#[derive(FromMeta)]
struct Meta {
#[meta(naked)]
code: Code,
error: Option<SpanWrapped<Dynamic>>,
status: Option<SpanWrapped<Dynamic>>,
}

/// `Some` if there's a code, `None` if it's `default`.
Expand All @@ -43,16 +100,60 @@ impl FromMeta for Code {

impl Attribute {
pub fn parse(args: TokenStream, input: proc_macro::TokenStream) -> Result<Self> {
let mut diags = Diagnostics::new();

let function: syn::ItemFn = syn::parse(input)
.map_err(Diagnostic::from)
.map_err(|diag| diag.help("`#[catch]` can only be used on functions"))?;

let attr: MetaItem = syn::parse2(quote!(catch(#args)))?;
let status = Meta::from_meta(&attr)
.map(|meta| meta.code.0)
let attr = Meta::from_meta(&attr)
.map(|meta| meta)
.map_err(|diag| diag.help("`#[catch]` expects a status code int or `default`: \
`#[catch(404)]` or `#[catch(default)]`"))?;

Ok(Attribute { status, function })
let span = function.sig.paren_token.span.join();
let mut arguments = Arguments { map: ArgumentMap::new(), span };
for arg in function.sig.inputs.iter() {
if let Some((ident, ty)) = arg.typed() {
let value = (ident.clone(), ty.with_stripped_lifetimes());
arguments.map.insert(Name::from(ident), value);
} else {
let span = arg.span();
let diag = if arg.wild().is_some() {
span.error("handler arguments must be named")
.help("to name an ignored handler argument, use `_name`")
} else {
span.error("handler arguments must be of the form `ident: Type`")
};

diags.push(diag);
}
}
let error_guard = attr.error.clone()
.map(|p| ErrorGuard::new(p, &arguments))
.and_then(|p| p.map_err(|e| diags.push(e)).ok());
let request_guards = arguments.map.iter()
.filter(|(name, _)| {
let mut all_other_guards = error_guard.iter()
.map(|g| &g.name);

all_other_guards.all(|n| n != *name)
})
.enumerate()
.map(|(index, (name, (ident, ty)))| Guard {
source: Dynamic { index, name: name.clone(), trailing: false },
fn_ident: ident.clone(),
ty: ty.clone(),
})
.collect();

diags.head_err_or(Attribute {
status: attr.code.0,
function,
arguments,
error_guard,
request_guards,
})
}
}
Loading