Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Autodiff batching #137880

Merged
merged 5 commits into from
Apr 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions compiler/rustc_ast/src/expand/autodiff_attrs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,17 @@ pub struct AutoDiffAttrs {
/// e.g. in the [JAX
/// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
pub mode: DiffMode,
/// A user-provided, batching width. If not given, we will default to 1 (no batching).
/// Calling a differentiated, non-batched function through a loop 100 times is equivalent to:
/// - Calling the function 50 times with a batch size of 2
/// - Calling the function 25 times with a batch size of 4,
/// etc. A batched function takes more (or longer) arguments, and might be able to benefit from
/// cache locality, better re-usal of primal values, and other optimizations.
/// We will (before LLVM's vectorizer runs) just generate most LLVM-IR instructions `width`
/// times, so this massively increases code size. As such, values like 1024 are unlikely to
/// work. We should consider limiting this to u8 or u16, but will leave it at u32 for
/// experiments for now and focus on documenting the implications of a large width.
pub width: u32,
pub ret_activity: DiffActivity,
pub input_activity: Vec<DiffActivity>,
}
Expand Down Expand Up @@ -222,13 +233,15 @@ impl AutoDiffAttrs {
pub const fn error() -> Self {
AutoDiffAttrs {
mode: DiffMode::Error,
width: 0,
ret_activity: DiffActivity::None,
input_activity: Vec::new(),
}
}
pub fn source() -> Self {
AutoDiffAttrs {
mode: DiffMode::Source,
width: 0,
ret_activity: DiffActivity::None,
input_activity: Vec::new(),
}
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_builtin_macros/messages.ftl
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ builtin_macros_autodiff_ret_activity = invalid return activity {$act} in {$mode}
builtin_macros_autodiff_ty_activity = {$act} can not be used for this type
builtin_macros_autodiff_unknown_activity = did not recognize Activity: `{$act}`
builtin_macros_autodiff_width = autodiff width must fit u32, but is {$width}
builtin_macros_bad_derive_target = `derive` may only be applied to `struct`s, `enum`s and `union`s
.label = not applicable here
.label2 = not a `struct`, `enum` or `union`
Expand Down
339 changes: 213 additions & 126 deletions compiler/rustc_builtin_macros/src/autodiff.rs

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions compiler/rustc_builtin_macros/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,14 @@ mod autodiff {
pub(crate) mode: String,
}

#[derive(Diagnostic)]
#[diag(builtin_macros_autodiff_width)]
pub(crate) struct AutoDiffInvalidWidth {
#[primary_span]
pub(crate) span: Span,
pub(crate) width: u128,
}

#[derive(Diagnostic)]
#[diag(builtin_macros_autodiff)]
pub(crate) struct AutoDiffInvalidApplication {
Expand Down
12 changes: 10 additions & 2 deletions compiler/rustc_codegen_llvm/src/back/lto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,8 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen<
}
// We handle this below
config::AutoDiff::PrintModAfter => {}
// We handle this below
config::AutoDiff::PrintModFinal => {}
// This is required and already checked
config::AutoDiff::Enable => {}
}
Expand Down Expand Up @@ -657,14 +659,20 @@ pub(crate) fn run_pass_manager(
}

if cfg!(llvm_enzyme) && enable_ad {
// This is the post-autodiff IR, mainly used for testing and educational purposes.
if config.autodiff.contains(&config::AutoDiff::PrintModAfter) {
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
}

let opt_stage = llvm::OptStage::FatLTO;
let stage = write::AutodiffStage::PostAD;
unsafe {
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
}

// This is the final IR, so people should be able to inspect the optimized autodiff output.
if config.autodiff.contains(&config::AutoDiff::PrintModAfter) {
// This is the final IR, so people should be able to inspect the optimized autodiff output,
// for manual inspection.
if config.autodiff.contains(&config::AutoDiff::PrintModFinal) {
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
}
}
Expand Down
199 changes: 160 additions & 39 deletions compiler/rustc_codegen_llvm/src/builder/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ use std::ptr;
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode};
use rustc_codegen_ssa::ModuleCodegen;
use rustc_codegen_ssa::back::write::ModuleConfig;
use rustc_codegen_ssa::traits::BaseTypeCodegenMethods as _;
use rustc_codegen_ssa::common::TypeKind;
use rustc_codegen_ssa::traits::BaseTypeCodegenMethods;
use rustc_errors::FatalError;
use rustc_middle::bug;
use tracing::{debug, trace};

use crate::back::write::llvm_err;
Expand All @@ -18,21 +20,42 @@ use crate::value::Value;
use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm};

fn get_params(fnc: &Value) -> Vec<&Value> {
let param_num = llvm::LLVMCountParams(fnc) as usize;
let mut fnc_args: Vec<&Value> = vec![];
fnc_args.reserve(param_num);
unsafe {
let param_num = llvm::LLVMCountParams(fnc) as usize;
let mut fnc_args: Vec<&Value> = vec![];
fnc_args.reserve(param_num);
llvm::LLVMGetParams(fnc, fnc_args.as_mut_ptr());
fnc_args.set_len(param_num);
fnc_args
}
fnc_args
}

fn has_sret(fnc: &Value) -> bool {
let num_args = llvm::LLVMCountParams(fnc) as usize;
if num_args == 0 {
false
} else {
unsafe { llvm::LLVMRustHasAttributeAtIndex(fnc, 0, llvm::AttributeKind::StructRet) }
}
}

// When we call the `__enzyme_autodiff` or `__enzyme_fwddiff` function, we need to pass all the
// original inputs, as well as metadata and the additional shadow arguments.
// This function matches the arguments from the outer function to the inner enzyme call.
//
// This function also considers that Rust level arguments not always match the llvm-ir level
// arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on
// llvm-ir level. The number of activities matches the number of Rust level arguments, so we
// need to match those.
// FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it
// using iterators and peek()?
fn match_args_from_caller_to_enzyme<'ll>(
cx: &SimpleCx<'ll>,
width: u32,
args: &mut Vec<&'ll llvm::Value>,
inputs: &[DiffActivity],
outer_args: &[&'ll llvm::Value],
has_sret: bool,
) {
debug!("matching autodiff arguments");
// We now handle the issue that Rust level arguments not always match the llvm-ir level
Expand All @@ -44,6 +67,14 @@ fn match_args_from_caller_to_enzyme<'ll>(
let mut outer_pos: usize = 0;
let mut activity_pos = 0;

if has_sret {
// Then the first outer arg is the sret pointer. Enzyme doesn't know about sret, so the
// inner function will still return something. We increase our outer_pos by one,
// and once we're done with all other args we will take the return of the inner call and
// update the sret pointer with it
outer_pos = 1;
}

let enzyme_const = cx.create_metadata("enzyme_const".to_string()).unwrap();
let enzyme_out = cx.create_metadata("enzyme_out".to_string()).unwrap();
let enzyme_dup = cx.create_metadata("enzyme_dup".to_string()).unwrap();
Expand Down Expand Up @@ -92,39 +123,40 @@ fn match_args_from_caller_to_enzyme<'ll>(
// (..., metadata! enzyme_dup, ptr, ptr, int1, ...).
// FIXME(ZuseZ4): We will upstream a safety check later which asserts that
// int2 >= int1, which means the shadow vector is large enough to store the gradient.
assert!(unsafe {
llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Integer
});
let next_outer_arg2 = outer_args[outer_pos + 2];
let next_outer_ty2 = cx.val_ty(next_outer_arg2);
assert!(unsafe {
llvm::LLVMRustGetTypeKind(next_outer_ty2) == llvm::TypeKind::Pointer
});
let next_outer_arg3 = outer_args[outer_pos + 3];
let next_outer_ty3 = cx.val_ty(next_outer_arg3);
assert!(unsafe {
llvm::LLVMRustGetTypeKind(next_outer_ty3) == llvm::TypeKind::Integer
});
args.push(next_outer_arg2);
assert_eq!(cx.type_kind(next_outer_ty), TypeKind::Integer);

for i in 0..(width as usize) {
let next_outer_arg2 = outer_args[outer_pos + 2 * (i + 1)];
let next_outer_ty2 = cx.val_ty(next_outer_arg2);
assert_eq!(cx.type_kind(next_outer_ty2), TypeKind::Pointer);
let next_outer_arg3 = outer_args[outer_pos + 2 * (i + 1) + 1];
let next_outer_ty3 = cx.val_ty(next_outer_arg3);
assert_eq!(cx.type_kind(next_outer_ty3), TypeKind::Integer);
args.push(next_outer_arg2);
}
args.push(cx.get_metadata_value(enzyme_const));
args.push(next_outer_arg);
outer_pos += 4;
outer_pos += 2 + 2 * width as usize;
activity_pos += 2;
} else {
// A duplicated pointer will have the following two outer_fn arguments:
// (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
// (..., metadata! enzyme_dup, ptr, ptr, ...).
if matches!(diff_activity, DiffActivity::Duplicated | DiffActivity::DuplicatedOnly)
{
assert!(
unsafe { llvm::LLVMRustGetTypeKind(next_outer_ty) }
== llvm::TypeKind::Pointer
);
assert_eq!(cx.type_kind(next_outer_ty), TypeKind::Pointer);
}
// In the case of Dual we don't have assumptions, e.g. f32 would be valid.
args.push(next_outer_arg);
outer_pos += 2;
activity_pos += 1;

// Now, if width > 1, we need to account for that
for _ in 1..width {
let next_outer_arg = outer_args[outer_pos];
args.push(next_outer_arg);
outer_pos += 1;
}
}
} else {
// We do not differentiate with resprect to this argument.
Expand All @@ -135,6 +167,76 @@ fn match_args_from_caller_to_enzyme<'ll>(
}
}

// On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
// arguments. We do however need to declare them with their correct return type.
// We already figured the correct return type out in our frontend, when generating the outer_fn,
// so we can now just go ahead and use that. This is not always trivial, e.g. because sret.
// Beyond sret, this article describes our challenges nicely:
// <https://yorickpeterse.com/articles/the-mess-that-is-handling-structure-arguments-and-returns-in-llvm/>
// I.e. (i32, f32) will get merged into i64, but we don't handle that yet.
fn compute_enzyme_fn_ty<'ll>(
cx: &SimpleCx<'ll>,
attrs: &AutoDiffAttrs,
fn_to_diff: &'ll Value,
outer_fn: &'ll Value,
) -> &'ll llvm::Type {
let fn_ty = cx.get_type_of_global(outer_fn);
let mut ret_ty = cx.get_return_type(fn_ty);

let has_sret = has_sret(outer_fn);

if has_sret {
// Now we don't just forward the return type, so we have to figure it out based on the
// primal return type, in combination with the autodiff settings.
let fn_ty = cx.get_type_of_global(fn_to_diff);
let inner_ret_ty = cx.get_return_type(fn_ty);

let void_ty = unsafe { llvm::LLVMVoidTypeInContext(cx.llcx) };
if inner_ret_ty == void_ty {
// This indicates that even the inner function has an sret.
// Right now I only look for an sret in the outer function.
// This *probably* needs some extra handling, but I never ran
// into such a case. So I'll wait for user reports to have a test case.
bug!("sret in inner function");
}

if attrs.width == 1 {
todo!("Handle sret for scalar ad");
} else {
// First we check if we also have to deal with the primal return.
match attrs.mode {
DiffMode::Forward => match attrs.ret_activity {
DiffActivity::Dual => {
let arr_ty =
unsafe { llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64 + 1) };
ret_ty = arr_ty;
}
DiffActivity::DualOnly => {
let arr_ty =
unsafe { llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64) };
ret_ty = arr_ty;
}
DiffActivity::Const => {
todo!("Not sure, do we need to do something here?");
}
_ => {
bug!("unreachable");
}
},
DiffMode::Reverse => {
todo!("Handle sret for reverse mode");
}
_ => {
bug!("unreachable");
}
}
}
}

// LLVM can figure out the input types on it's own, so we take a shortcut here.
unsafe { llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True) }
}

/// When differentiating `fn_to_diff`, take a `outer_fn` and generate another
/// function with expected naming and calling conventions[^1] which will be
/// discovered by the enzyme LLVM pass and its body populated with the differentiated
Expand Down Expand Up @@ -197,17 +299,9 @@ fn generate_enzyme_call<'ll>(
// }
// ```
unsafe {
// On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
// arguments. We do however need to declare them with their correct return type.
// We already figured the correct return type out in our frontend, when generating the outer_fn,
// so we can now just go ahead and use that. FIXME(ZuseZ4): This doesn't handle sret yet.
let fn_ty = llvm::LLVMGlobalGetValueType(outer_fn);
let ret_ty = llvm::LLVMGetReturnType(fn_ty);

// LLVM can figure out the input types on it's own, so we take a shortcut here.
let enzyme_ty = llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True);
let enzyme_ty = compute_enzyme_fn_ty(cx, &attrs, fn_to_diff, outer_fn);

//FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
// FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
// think a bit more about what should go here.
let cc = llvm::LLVMGetFunctionCallConv(outer_fn);
let ad_fn = declare_simple_fn(
Expand Down Expand Up @@ -240,14 +334,27 @@ fn generate_enzyme_call<'ll>(
if matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) {
args.push(cx.get_metadata_value(enzyme_primal_ret));
}
if attrs.width > 1 {
let enzyme_width = cx.create_metadata("enzyme_width".to_string()).unwrap();
args.push(cx.get_metadata_value(enzyme_width));
args.push(cx.get_const_i64(attrs.width as u64));
}

let has_sret = has_sret(outer_fn);
let outer_args: Vec<&llvm::Value> = get_params(outer_fn);
match_args_from_caller_to_enzyme(&cx, &mut args, &attrs.input_activity, &outer_args);
match_args_from_caller_to_enzyme(
&cx,
attrs.width,
&mut args,
&attrs.input_activity,
&outer_args,
has_sret,
);

let call = builder.call(enzyme_ty, ad_fn, &args, None);

// This part is a bit iffy. LLVM requires that a call to an inlineable function has some
// metadata attachted to it, but we just created this code oota. Given that the
// metadata attached to it, but we just created this code oota. Given that the
// differentiated function already has partly confusing metadata, and given that this
// affects nothing but the auttodiff IR, we take a shortcut and just steal metadata from the
// dummy code which we inserted at a higher level.
Expand All @@ -268,7 +375,22 @@ fn generate_enzyme_call<'ll>(
// Now that we copied the metadata, get rid of dummy code.
llvm::LLVMRustEraseInstUntilInclusive(entry, last_inst);

if cx.val_ty(call) == cx.type_void() {
if cx.val_ty(call) == cx.type_void() || has_sret {
if has_sret {
// This is what we already have in our outer_fn (shortened):
// define void @_foo(ptr <..> sret([32 x i8]) initializes((0, 32)) %0, <...>) {
// %7 = call [4 x double] (...) @__enzyme_fwddiff_foo(ptr @square, metadata !"enzyme_width", i64 4, <...>)
// <Here we are, we want to add the following two lines>
// store [4 x double] %7, ptr %0, align 8
// ret void
// }

// now store the result of the enzyme call into the sret pointer.
let sret_ptr = outer_args[0];
let call_ty = cx.val_ty(call);
assert_eq!(cx.type_kind(call_ty), TypeKind::Array);
llvm::LLVMBuildStore(&builder.llbuilder, call, sret_ptr);
}
builder.ret_void();
} else {
builder.ret(call);
Expand Down Expand Up @@ -300,8 +422,7 @@ pub(crate) fn differentiate<'ll>(
if !diff_items.is_empty()
&& !cgcx.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable)
{
let dcx = cgcx.create_dcx();
return Err(dcx.handle().emit_almost_fatal(AutoDiffWithoutEnable));
return Err(diag_handler.handle().emit_almost_fatal(AutoDiffWithoutEnable));
}

// Before dumping the module, we want all the TypeTrees to become part of the module.
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_codegen_llvm/src/consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ impl<'ll> CodegenCx<'ll, '_> {
let val_llty = self.val_ty(v);

let g = self.get_static_inner(def_id, val_llty);
let llty = llvm::LLVMGlobalGetValueType(g);
let llty = self.get_type_of_global(g);

let g = if val_llty == llty {
g
Expand Down
Loading
Loading