Skip to content

Commit c6bf3a0

Browse files
authored
Rollup merge of #137880 - EnzymeAD:autodiff-batching, r=oli-obk
Autodiff batching Enzyme supports batching, which is especially known from the ML side when training neural networks. There we would normally have a training loop, where in each iteration we would pass in some data (e.g. an image), and a target vector. Based on how close we are with our prediction we compute our loss, and then use backpropagation to compute the gradients and update our weights. That's quite inefficient, so what you normally do is passing in a batch of 8/16/.. images and targets, and compute the gradients for those all at once, allowing better optimizations. Enzyme supports batching in two ways, the first one (which I implemented here) just accepts a Batch size, and then each Dual/Duplicated argument has not one, but N shadow arguments. So instead of ```rs for i in 0..100 { df(x[i], y[i], 1234); } ``` You can now do ```rs for i in 0..100.step_by(4) { df(x[i+0],x[i+1],x[i+2],x[i+3], y[i+0], y[i+1], y[i+2], y[i+3], 1234); } ``` which will give the same results, but allows better compiler optimizations. See the testcase for details. There is a second variant, where we can mark certain arguments and instead of having to pass in N shadow arguments, Enzyme assumes that the argument is N times longer. I.e. instead of accepting 4 slices with 12 floats each, we would accept one slice with 48 floats. I'll implement this over the next days. I will also add more tests for both modes. For any one preferring some more interactive explanation, here's a video of Tim's llvm dev talk, where he presents his work. https://www.youtube.com/watch?v=edvaLAL5RqU I'll also add some other docs to the dev guide and user docs in another PR. r? ghost Tracking: - #124509 - #135283
2 parents 2e4e196 + 89d8948 commit c6bf3a0

File tree

21 files changed

+727
-233
lines changed

21 files changed

+727
-233
lines changed

compiler/rustc_ast/src/expand/autodiff_attrs.rs

+13
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,17 @@ pub struct AutoDiffAttrs {
7777
/// e.g. in the [JAX
7878
/// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
7979
pub mode: DiffMode,
80+
/// A user-provided, batching width. If not given, we will default to 1 (no batching).
81+
/// Calling a differentiated, non-batched function through a loop 100 times is equivalent to:
82+
/// - Calling the function 50 times with a batch size of 2
83+
/// - Calling the function 25 times with a batch size of 4,
84+
/// etc. A batched function takes more (or longer) arguments, and might be able to benefit from
85+
/// cache locality, better re-usal of primal values, and other optimizations.
86+
/// We will (before LLVM's vectorizer runs) just generate most LLVM-IR instructions `width`
87+
/// times, so this massively increases code size. As such, values like 1024 are unlikely to
88+
/// work. We should consider limiting this to u8 or u16, but will leave it at u32 for
89+
/// experiments for now and focus on documenting the implications of a large width.
90+
pub width: u32,
8091
pub ret_activity: DiffActivity,
8192
pub input_activity: Vec<DiffActivity>,
8293
}
@@ -222,13 +233,15 @@ impl AutoDiffAttrs {
222233
pub const fn error() -> Self {
223234
AutoDiffAttrs {
224235
mode: DiffMode::Error,
236+
width: 0,
225237
ret_activity: DiffActivity::None,
226238
input_activity: Vec::new(),
227239
}
228240
}
229241
pub fn source() -> Self {
230242
AutoDiffAttrs {
231243
mode: DiffMode::Source,
244+
width: 0,
232245
ret_activity: DiffActivity::None,
233246
input_activity: Vec::new(),
234247
}

compiler/rustc_builtin_macros/messages.ftl

+1
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ builtin_macros_autodiff_ret_activity = invalid return activity {$act} in {$mode}
7979
builtin_macros_autodiff_ty_activity = {$act} can not be used for this type
8080
builtin_macros_autodiff_unknown_activity = did not recognize Activity: `{$act}`
8181
82+
builtin_macros_autodiff_width = autodiff width must fit u32, but is {$width}
8283
builtin_macros_bad_derive_target = `derive` may only be applied to `struct`s, `enum`s and `union`s
8384
.label = not applicable here
8485
.label2 = not a `struct`, `enum` or `union`

compiler/rustc_builtin_macros/src/autodiff.rs

+213-126
Large diffs are not rendered by default.

compiler/rustc_builtin_macros/src/errors.rs

+8
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,14 @@ mod autodiff {
202202
pub(crate) mode: String,
203203
}
204204

205+
#[derive(Diagnostic)]
206+
#[diag(builtin_macros_autodiff_width)]
207+
pub(crate) struct AutoDiffInvalidWidth {
208+
#[primary_span]
209+
pub(crate) span: Span,
210+
pub(crate) width: u128,
211+
}
212+
205213
#[derive(Diagnostic)]
206214
#[diag(builtin_macros_autodiff)]
207215
pub(crate) struct AutoDiffInvalidApplication {

compiler/rustc_codegen_llvm/src/back/lto.rs

+10-2
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,8 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen<
610610
}
611611
// We handle this below
612612
config::AutoDiff::PrintModAfter => {}
613+
// We handle this below
614+
config::AutoDiff::PrintModFinal => {}
613615
// This is required and already checked
614616
config::AutoDiff::Enable => {}
615617
}
@@ -657,14 +659,20 @@ pub(crate) fn run_pass_manager(
657659
}
658660

659661
if cfg!(llvm_enzyme) && enable_ad {
662+
// This is the post-autodiff IR, mainly used for testing and educational purposes.
663+
if config.autodiff.contains(&config::AutoDiff::PrintModAfter) {
664+
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
665+
}
666+
660667
let opt_stage = llvm::OptStage::FatLTO;
661668
let stage = write::AutodiffStage::PostAD;
662669
unsafe {
663670
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
664671
}
665672

666-
// This is the final IR, so people should be able to inspect the optimized autodiff output.
667-
if config.autodiff.contains(&config::AutoDiff::PrintModAfter) {
673+
// This is the final IR, so people should be able to inspect the optimized autodiff output,
674+
// for manual inspection.
675+
if config.autodiff.contains(&config::AutoDiff::PrintModFinal) {
668676
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
669677
}
670678
}

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

+160-39
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@ use std::ptr;
33
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode};
44
use rustc_codegen_ssa::ModuleCodegen;
55
use rustc_codegen_ssa::back::write::ModuleConfig;
6-
use rustc_codegen_ssa::traits::BaseTypeCodegenMethods as _;
6+
use rustc_codegen_ssa::common::TypeKind;
7+
use rustc_codegen_ssa::traits::BaseTypeCodegenMethods;
78
use rustc_errors::FatalError;
9+
use rustc_middle::bug;
810
use tracing::{debug, trace};
911

1012
use crate::back::write::llvm_err;
@@ -18,21 +20,42 @@ use crate::value::Value;
1820
use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm};
1921

2022
fn get_params(fnc: &Value) -> Vec<&Value> {
23+
let param_num = llvm::LLVMCountParams(fnc) as usize;
24+
let mut fnc_args: Vec<&Value> = vec![];
25+
fnc_args.reserve(param_num);
2126
unsafe {
22-
let param_num = llvm::LLVMCountParams(fnc) as usize;
23-
let mut fnc_args: Vec<&Value> = vec![];
24-
fnc_args.reserve(param_num);
2527
llvm::LLVMGetParams(fnc, fnc_args.as_mut_ptr());
2628
fnc_args.set_len(param_num);
27-
fnc_args
2829
}
30+
fnc_args
2931
}
3032

33+
fn has_sret(fnc: &Value) -> bool {
34+
let num_args = llvm::LLVMCountParams(fnc) as usize;
35+
if num_args == 0 {
36+
false
37+
} else {
38+
unsafe { llvm::LLVMRustHasAttributeAtIndex(fnc, 0, llvm::AttributeKind::StructRet) }
39+
}
40+
}
41+
42+
// When we call the `__enzyme_autodiff` or `__enzyme_fwddiff` function, we need to pass all the
43+
// original inputs, as well as metadata and the additional shadow arguments.
44+
// This function matches the arguments from the outer function to the inner enzyme call.
45+
//
46+
// This function also considers that Rust level arguments not always match the llvm-ir level
47+
// arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on
48+
// llvm-ir level. The number of activities matches the number of Rust level arguments, so we
49+
// need to match those.
50+
// FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it
51+
// using iterators and peek()?
3152
fn match_args_from_caller_to_enzyme<'ll>(
3253
cx: &SimpleCx<'ll>,
54+
width: u32,
3355
args: &mut Vec<&'ll llvm::Value>,
3456
inputs: &[DiffActivity],
3557
outer_args: &[&'ll llvm::Value],
58+
has_sret: bool,
3659
) {
3760
debug!("matching autodiff arguments");
3861
// We now handle the issue that Rust level arguments not always match the llvm-ir level
@@ -44,6 +67,14 @@ fn match_args_from_caller_to_enzyme<'ll>(
4467
let mut outer_pos: usize = 0;
4568
let mut activity_pos = 0;
4669

70+
if has_sret {
71+
// Then the first outer arg is the sret pointer. Enzyme doesn't know about sret, so the
72+
// inner function will still return something. We increase our outer_pos by one,
73+
// and once we're done with all other args we will take the return of the inner call and
74+
// update the sret pointer with it
75+
outer_pos = 1;
76+
}
77+
4778
let enzyme_const = cx.create_metadata("enzyme_const".to_string()).unwrap();
4879
let enzyme_out = cx.create_metadata("enzyme_out".to_string()).unwrap();
4980
let enzyme_dup = cx.create_metadata("enzyme_dup".to_string()).unwrap();
@@ -92,39 +123,40 @@ fn match_args_from_caller_to_enzyme<'ll>(
92123
// (..., metadata! enzyme_dup, ptr, ptr, int1, ...).
93124
// FIXME(ZuseZ4): We will upstream a safety check later which asserts that
94125
// int2 >= int1, which means the shadow vector is large enough to store the gradient.
95-
assert!(unsafe {
96-
llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Integer
97-
});
98-
let next_outer_arg2 = outer_args[outer_pos + 2];
99-
let next_outer_ty2 = cx.val_ty(next_outer_arg2);
100-
assert!(unsafe {
101-
llvm::LLVMRustGetTypeKind(next_outer_ty2) == llvm::TypeKind::Pointer
102-
});
103-
let next_outer_arg3 = outer_args[outer_pos + 3];
104-
let next_outer_ty3 = cx.val_ty(next_outer_arg3);
105-
assert!(unsafe {
106-
llvm::LLVMRustGetTypeKind(next_outer_ty3) == llvm::TypeKind::Integer
107-
});
108-
args.push(next_outer_arg2);
126+
assert_eq!(cx.type_kind(next_outer_ty), TypeKind::Integer);
127+
128+
for i in 0..(width as usize) {
129+
let next_outer_arg2 = outer_args[outer_pos + 2 * (i + 1)];
130+
let next_outer_ty2 = cx.val_ty(next_outer_arg2);
131+
assert_eq!(cx.type_kind(next_outer_ty2), TypeKind::Pointer);
132+
let next_outer_arg3 = outer_args[outer_pos + 2 * (i + 1) + 1];
133+
let next_outer_ty3 = cx.val_ty(next_outer_arg3);
134+
assert_eq!(cx.type_kind(next_outer_ty3), TypeKind::Integer);
135+
args.push(next_outer_arg2);
136+
}
109137
args.push(cx.get_metadata_value(enzyme_const));
110138
args.push(next_outer_arg);
111-
outer_pos += 4;
139+
outer_pos += 2 + 2 * width as usize;
112140
activity_pos += 2;
113141
} else {
114142
// A duplicated pointer will have the following two outer_fn arguments:
115143
// (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
116144
// (..., metadata! enzyme_dup, ptr, ptr, ...).
117145
if matches!(diff_activity, DiffActivity::Duplicated | DiffActivity::DuplicatedOnly)
118146
{
119-
assert!(
120-
unsafe { llvm::LLVMRustGetTypeKind(next_outer_ty) }
121-
== llvm::TypeKind::Pointer
122-
);
147+
assert_eq!(cx.type_kind(next_outer_ty), TypeKind::Pointer);
123148
}
124149
// In the case of Dual we don't have assumptions, e.g. f32 would be valid.
125150
args.push(next_outer_arg);
126151
outer_pos += 2;
127152
activity_pos += 1;
153+
154+
// Now, if width > 1, we need to account for that
155+
for _ in 1..width {
156+
let next_outer_arg = outer_args[outer_pos];
157+
args.push(next_outer_arg);
158+
outer_pos += 1;
159+
}
128160
}
129161
} else {
130162
// We do not differentiate with resprect to this argument.
@@ -135,6 +167,76 @@ fn match_args_from_caller_to_enzyme<'ll>(
135167
}
136168
}
137169

170+
// On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
171+
// arguments. We do however need to declare them with their correct return type.
172+
// We already figured the correct return type out in our frontend, when generating the outer_fn,
173+
// so we can now just go ahead and use that. This is not always trivial, e.g. because sret.
174+
// Beyond sret, this article describes our challenges nicely:
175+
// <https://yorickpeterse.com/articles/the-mess-that-is-handling-structure-arguments-and-returns-in-llvm/>
176+
// I.e. (i32, f32) will get merged into i64, but we don't handle that yet.
177+
fn compute_enzyme_fn_ty<'ll>(
178+
cx: &SimpleCx<'ll>,
179+
attrs: &AutoDiffAttrs,
180+
fn_to_diff: &'ll Value,
181+
outer_fn: &'ll Value,
182+
) -> &'ll llvm::Type {
183+
let fn_ty = cx.get_type_of_global(outer_fn);
184+
let mut ret_ty = cx.get_return_type(fn_ty);
185+
186+
let has_sret = has_sret(outer_fn);
187+
188+
if has_sret {
189+
// Now we don't just forward the return type, so we have to figure it out based on the
190+
// primal return type, in combination with the autodiff settings.
191+
let fn_ty = cx.get_type_of_global(fn_to_diff);
192+
let inner_ret_ty = cx.get_return_type(fn_ty);
193+
194+
let void_ty = unsafe { llvm::LLVMVoidTypeInContext(cx.llcx) };
195+
if inner_ret_ty == void_ty {
196+
// This indicates that even the inner function has an sret.
197+
// Right now I only look for an sret in the outer function.
198+
// This *probably* needs some extra handling, but I never ran
199+
// into such a case. So I'll wait for user reports to have a test case.
200+
bug!("sret in inner function");
201+
}
202+
203+
if attrs.width == 1 {
204+
todo!("Handle sret for scalar ad");
205+
} else {
206+
// First we check if we also have to deal with the primal return.
207+
match attrs.mode {
208+
DiffMode::Forward => match attrs.ret_activity {
209+
DiffActivity::Dual => {
210+
let arr_ty =
211+
unsafe { llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64 + 1) };
212+
ret_ty = arr_ty;
213+
}
214+
DiffActivity::DualOnly => {
215+
let arr_ty =
216+
unsafe { llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64) };
217+
ret_ty = arr_ty;
218+
}
219+
DiffActivity::Const => {
220+
todo!("Not sure, do we need to do something here?");
221+
}
222+
_ => {
223+
bug!("unreachable");
224+
}
225+
},
226+
DiffMode::Reverse => {
227+
todo!("Handle sret for reverse mode");
228+
}
229+
_ => {
230+
bug!("unreachable");
231+
}
232+
}
233+
}
234+
}
235+
236+
// LLVM can figure out the input types on it's own, so we take a shortcut here.
237+
unsafe { llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True) }
238+
}
239+
138240
/// When differentiating `fn_to_diff`, take a `outer_fn` and generate another
139241
/// function with expected naming and calling conventions[^1] which will be
140242
/// discovered by the enzyme LLVM pass and its body populated with the differentiated
@@ -197,17 +299,9 @@ fn generate_enzyme_call<'ll>(
197299
// }
198300
// ```
199301
unsafe {
200-
// On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
201-
// arguments. We do however need to declare them with their correct return type.
202-
// We already figured the correct return type out in our frontend, when generating the outer_fn,
203-
// so we can now just go ahead and use that. FIXME(ZuseZ4): This doesn't handle sret yet.
204-
let fn_ty = llvm::LLVMGlobalGetValueType(outer_fn);
205-
let ret_ty = llvm::LLVMGetReturnType(fn_ty);
206-
207-
// LLVM can figure out the input types on it's own, so we take a shortcut here.
208-
let enzyme_ty = llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True);
302+
let enzyme_ty = compute_enzyme_fn_ty(cx, &attrs, fn_to_diff, outer_fn);
209303

210-
//FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
304+
// FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
211305
// think a bit more about what should go here.
212306
let cc = llvm::LLVMGetFunctionCallConv(outer_fn);
213307
let ad_fn = declare_simple_fn(
@@ -240,14 +334,27 @@ fn generate_enzyme_call<'ll>(
240334
if matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) {
241335
args.push(cx.get_metadata_value(enzyme_primal_ret));
242336
}
337+
if attrs.width > 1 {
338+
let enzyme_width = cx.create_metadata("enzyme_width".to_string()).unwrap();
339+
args.push(cx.get_metadata_value(enzyme_width));
340+
args.push(cx.get_const_i64(attrs.width as u64));
341+
}
243342

343+
let has_sret = has_sret(outer_fn);
244344
let outer_args: Vec<&llvm::Value> = get_params(outer_fn);
245-
match_args_from_caller_to_enzyme(&cx, &mut args, &attrs.input_activity, &outer_args);
345+
match_args_from_caller_to_enzyme(
346+
&cx,
347+
attrs.width,
348+
&mut args,
349+
&attrs.input_activity,
350+
&outer_args,
351+
has_sret,
352+
);
246353

247354
let call = builder.call(enzyme_ty, ad_fn, &args, None);
248355

249356
// This part is a bit iffy. LLVM requires that a call to an inlineable function has some
250-
// metadata attachted to it, but we just created this code oota. Given that the
357+
// metadata attached to it, but we just created this code oota. Given that the
251358
// differentiated function already has partly confusing metadata, and given that this
252359
// affects nothing but the auttodiff IR, we take a shortcut and just steal metadata from the
253360
// dummy code which we inserted at a higher level.
@@ -268,7 +375,22 @@ fn generate_enzyme_call<'ll>(
268375
// Now that we copied the metadata, get rid of dummy code.
269376
llvm::LLVMRustEraseInstUntilInclusive(entry, last_inst);
270377

271-
if cx.val_ty(call) == cx.type_void() {
378+
if cx.val_ty(call) == cx.type_void() || has_sret {
379+
if has_sret {
380+
// This is what we already have in our outer_fn (shortened):
381+
// define void @_foo(ptr <..> sret([32 x i8]) initializes((0, 32)) %0, <...>) {
382+
// %7 = call [4 x double] (...) @__enzyme_fwddiff_foo(ptr @square, metadata !"enzyme_width", i64 4, <...>)
383+
// <Here we are, we want to add the following two lines>
384+
// store [4 x double] %7, ptr %0, align 8
385+
// ret void
386+
// }
387+
388+
// now store the result of the enzyme call into the sret pointer.
389+
let sret_ptr = outer_args[0];
390+
let call_ty = cx.val_ty(call);
391+
assert_eq!(cx.type_kind(call_ty), TypeKind::Array);
392+
llvm::LLVMBuildStore(&builder.llbuilder, call, sret_ptr);
393+
}
272394
builder.ret_void();
273395
} else {
274396
builder.ret(call);
@@ -300,8 +422,7 @@ pub(crate) fn differentiate<'ll>(
300422
if !diff_items.is_empty()
301423
&& !cgcx.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable)
302424
{
303-
let dcx = cgcx.create_dcx();
304-
return Err(dcx.handle().emit_almost_fatal(AutoDiffWithoutEnable));
425+
return Err(diag_handler.handle().emit_almost_fatal(AutoDiffWithoutEnable));
305426
}
306427

307428
// Before dumping the module, we want all the TypeTrees to become part of the module.

compiler/rustc_codegen_llvm/src/consts.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@ impl<'ll> CodegenCx<'ll, '_> {
430430
let val_llty = self.val_ty(v);
431431

432432
let g = self.get_static_inner(def_id, val_llty);
433-
let llty = llvm::LLVMGlobalGetValueType(g);
433+
let llty = self.get_type_of_global(g);
434434

435435
let g = if val_llty == llty {
436436
g

0 commit comments

Comments
 (0)