Skip to content

Commit ea0f756

Browse files
authored
Unrolled build for rust-lang#138314
Rollup merge of rust-lang#138314 - haenoe:autodiff-inner-function, r=ZuseZ4 fix usage of `autodiff` macro with inner functions This PR adds additional handling into the expansion step of the `std::autodiff` macro (in `compiler/rustc_builtin_macros/src/autodiff.rs`), which allows the macro to be applied to inner functions. ```rust #![feature(autodiff)] use std::autodiff::autodiff; fn main() { #[autodiff(d_inner, Forward, Dual, DualOnly)] fn inner(x: f32) -> f32 { x * x } } ``` Previously, the compiler didn't allow this due to only handling `Annotatable::Item` and `Annotatable::AssocItem` and missing the handling of `Annotatable::Stmt`. This resulted in the rather generic error ``` error: autodiff must be applied to function --> src/main.rs:6:5 | 6 | / fn inner(x: f32) -> f32 { 7 | | x * x 8 | | } | |_____^ error: could not compile `enzyme-test` (bin "enzyme-test") due to 1 previous error ``` This issue was originally reported [here](EnzymeAD#184). Quick question: would it make sense to add a ui test to ensure there is no regression on this? This is my first contribution, so I'm extra grateful for any piece of feedback!! :D r? `@oli-obk` Tracking issue for autodiff: rust-lang#124509
2 parents b9856b6 + bf69443 commit ea0f756

File tree

3 files changed

+109
-48
lines changed

3 files changed

+109
-48
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

+77-48
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ mod llvm_enzyme {
1717
use rustc_ast::visit::AssocCtxt::*;
1818
use rustc_ast::{
1919
self as ast, AssocItemKind, BindingMode, ExprKind, FnRetTy, FnSig, Generics, ItemKind,
20-
MetaItemInner, PatKind, QSelf, TyKind,
20+
MetaItemInner, PatKind, QSelf, TyKind, Visibility,
2121
};
2222
use rustc_expand::base::{Annotatable, ExtCtxt};
2323
use rustc_span::{Ident, Span, Symbol, kw, sym};
@@ -72,6 +72,16 @@ mod llvm_enzyme {
7272
}
7373
}
7474

75+
// Get information about the function the macro is applied to
76+
fn extract_item_info(iitem: &P<ast::Item>) -> Option<(Visibility, FnSig, Ident)> {
77+
match &iitem.kind {
78+
ItemKind::Fn(box ast::Fn { sig, ident, .. }) => {
79+
Some((iitem.vis.clone(), sig.clone(), ident.clone()))
80+
}
81+
_ => None,
82+
}
83+
}
84+
7585
pub(crate) fn from_ast(
7686
ecx: &mut ExtCtxt<'_>,
7787
meta_item: &ThinVec<MetaItemInner>,
@@ -199,32 +209,26 @@ mod llvm_enzyme {
199209
return vec![item];
200210
}
201211
let dcx = ecx.sess.dcx();
202-
// first get the annotable item:
203-
let (primal, sig, is_impl): (Ident, FnSig, bool) = match &item {
204-
Annotatable::Item(iitem) => {
205-
let (ident, sig) = match &iitem.kind {
206-
ItemKind::Fn(box ast::Fn { ident, sig, .. }) => (ident, sig),
207-
_ => {
208-
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
209-
return vec![item];
210-
}
211-
};
212-
(*ident, sig.clone(), false)
213-
}
212+
213+
// first get information about the annotable item:
214+
let Some((vis, sig, primal)) = (match &item {
215+
Annotatable::Item(iitem) => extract_item_info(iitem),
216+
Annotatable::Stmt(stmt) => match &stmt.kind {
217+
ast::StmtKind::Item(iitem) => extract_item_info(iitem),
218+
_ => None,
219+
},
214220
Annotatable::AssocItem(assoc_item, Impl { of_trait: false }) => {
215-
let (ident, sig) = match &assoc_item.kind {
216-
ast::AssocItemKind::Fn(box ast::Fn { ident, sig, .. }) => (ident, sig),
217-
_ => {
218-
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
219-
return vec![item];
221+
match &assoc_item.kind {
222+
ast::AssocItemKind::Fn(box ast::Fn { sig, ident, .. }) => {
223+
Some((assoc_item.vis.clone(), sig.clone(), ident.clone()))
220224
}
221-
};
222-
(*ident, sig.clone(), true)
223-
}
224-
_ => {
225-
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
226-
return vec![item];
225+
_ => None,
226+
}
227227
}
228+
_ => None,
229+
}) else {
230+
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
231+
return vec![item];
228232
};
229233

230234
let meta_item_vec: ThinVec<MetaItemInner> = match meta_item.kind {
@@ -238,15 +242,6 @@ mod llvm_enzyme {
238242
let has_ret = has_ret(&sig.decl.output);
239243
let sig_span = ecx.with_call_site_ctxt(sig.span);
240244

241-
let vis = match &item {
242-
Annotatable::Item(iitem) => iitem.vis.clone(),
243-
Annotatable::AssocItem(assoc_item, _) => assoc_item.vis.clone(),
244-
_ => {
245-
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
246-
return vec![item];
247-
}
248-
};
249-
250245
// create TokenStream from vec elemtents:
251246
// meta_item doesn't have a .tokens field
252247
let mut ts: Vec<TokenTree> = vec![];
@@ -379,6 +374,22 @@ mod llvm_enzyme {
379374
}
380375
Annotatable::AssocItem(assoc_item.clone(), i)
381376
}
377+
Annotatable::Stmt(ref mut stmt) => {
378+
match stmt.kind {
379+
ast::StmtKind::Item(ref mut iitem) => {
380+
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
381+
iitem.attrs.push(attr);
382+
}
383+
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind))
384+
{
385+
iitem.attrs.push(inline_never.clone());
386+
}
387+
}
388+
_ => unreachable!("stmt kind checked previously"),
389+
};
390+
391+
Annotatable::Stmt(stmt.clone())
392+
}
382393
_ => {
383394
unreachable!("annotatable kind checked previously")
384395
}
@@ -389,22 +400,40 @@ mod llvm_enzyme {
389400
delim: rustc_ast::token::Delimiter::Parenthesis,
390401
tokens: ts,
391402
});
403+
392404
let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
393-
let d_annotatable = if is_impl {
394-
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
395-
let d_fn = P(ast::AssocItem {
396-
attrs: thin_vec![d_attr, inline_never],
397-
id: ast::DUMMY_NODE_ID,
398-
span,
399-
vis,
400-
kind: assoc_item,
401-
tokens: None,
402-
});
403-
Annotatable::AssocItem(d_fn, Impl { of_trait: false })
404-
} else {
405-
let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
406-
d_fn.vis = vis;
407-
Annotatable::Item(d_fn)
405+
let d_annotatable = match &item {
406+
Annotatable::AssocItem(_, _) => {
407+
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
408+
let d_fn = P(ast::AssocItem {
409+
attrs: thin_vec![d_attr, inline_never],
410+
id: ast::DUMMY_NODE_ID,
411+
span,
412+
vis,
413+
kind: assoc_item,
414+
tokens: None,
415+
});
416+
Annotatable::AssocItem(d_fn, Impl { of_trait: false })
417+
}
418+
Annotatable::Item(_) => {
419+
let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
420+
d_fn.vis = vis;
421+
422+
Annotatable::Item(d_fn)
423+
}
424+
Annotatable::Stmt(_) => {
425+
let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
426+
d_fn.vis = vis;
427+
428+
Annotatable::Stmt(P(ast::Stmt {
429+
id: ast::DUMMY_NODE_ID,
430+
kind: ast::StmtKind::Item(d_fn),
431+
span,
432+
}))
433+
}
434+
_ => {
435+
unreachable!("item kind checked previously")
436+
}
408437
};
409438

410439
return vec![orig_annotatable, d_annotatable];

tests/pretty/autodiff_forward.pp

+23
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
// Make sure, that we add the None for the default return.
3030

3131

32+
// We want to make sure that we can use the macro for functions defined inside of functions
33+
3234
::core::panicking::panic("not implemented")
3335
}
3436
#[rustc_autodiff(Forward, 1, Dual, Const, Dual)]
@@ -158,4 +160,25 @@
158160
::core::hint::black_box((bx_0,));
159161
::core::hint::black_box(<f32>::default())
160162
}
163+
pub fn f9() {
164+
#[rustc_autodiff]
165+
#[inline(never)]
166+
fn inner(x: f32) -> f32 { x * x }
167+
#[rustc_autodiff(Forward, 1, Dual, Dual)]
168+
#[inline(never)]
169+
fn d_inner_2(x: f32, bx_0: f32) -> (f32, f32) {
170+
unsafe { asm!("NOP", options(pure, nomem)); };
171+
::core::hint::black_box(inner(x));
172+
::core::hint::black_box((bx_0,));
173+
::core::hint::black_box(<(f32, f32)>::default())
174+
}
175+
#[rustc_autodiff(Forward, 1, Dual, DualOnly)]
176+
#[inline(never)]
177+
fn d_inner_1(x: f32, bx_0: f32) -> f32 {
178+
unsafe { asm!("NOP", options(pure, nomem)); };
179+
::core::hint::black_box(inner(x));
180+
::core::hint::black_box((bx_0,));
181+
::core::hint::black_box(<f32>::default())
182+
}
183+
}
161184
fn main() {}

tests/pretty/autodiff_forward.rs

+9
Original file line numberDiff line numberDiff line change
@@ -54,4 +54,13 @@ fn f8(x: &f32) -> f32 {
5454
unimplemented!()
5555
}
5656

57+
// We want to make sure that we can use the macro for functions defined inside of functions
58+
pub fn f9() {
59+
#[autodiff(d_inner_1, Forward, Dual, DualOnly)]
60+
#[autodiff(d_inner_2, Forward, Dual, Dual)]
61+
fn inner(x: f32) -> f32 {
62+
x * x
63+
}
64+
}
65+
5766
fn main() {}

0 commit comments

Comments
 (0)