@@ -3,8 +3,10 @@ use std::ptr;
3
3
use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffAttrs , AutoDiffItem , DiffActivity , DiffMode } ;
4
4
use rustc_codegen_ssa:: ModuleCodegen ;
5
5
use rustc_codegen_ssa:: back:: write:: ModuleConfig ;
6
- use rustc_codegen_ssa:: traits:: BaseTypeCodegenMethods as _;
6
+ use rustc_codegen_ssa:: common:: TypeKind ;
7
+ use rustc_codegen_ssa:: traits:: BaseTypeCodegenMethods ;
7
8
use rustc_errors:: FatalError ;
9
+ use rustc_middle:: bug;
8
10
use tracing:: { debug, trace} ;
9
11
10
12
use crate :: back:: write:: llvm_err;
@@ -18,21 +20,42 @@ use crate::value::Value;
18
20
use crate :: { CodegenContext , LlvmCodegenBackend , ModuleLlvm , attributes, llvm} ;
19
21
20
22
fn get_params ( fnc : & Value ) -> Vec < & Value > {
23
+ let param_num = llvm:: LLVMCountParams ( fnc) as usize ;
24
+ let mut fnc_args: Vec < & Value > = vec ! [ ] ;
25
+ fnc_args. reserve ( param_num) ;
21
26
unsafe {
22
- let param_num = llvm:: LLVMCountParams ( fnc) as usize ;
23
- let mut fnc_args: Vec < & Value > = vec ! [ ] ;
24
- fnc_args. reserve ( param_num) ;
25
27
llvm:: LLVMGetParams ( fnc, fnc_args. as_mut_ptr ( ) ) ;
26
28
fnc_args. set_len ( param_num) ;
27
- fnc_args
28
29
}
30
+ fnc_args
29
31
}
30
32
33
+ fn has_sret ( fnc : & Value ) -> bool {
34
+ let num_args = llvm:: LLVMCountParams ( fnc) as usize ;
35
+ if num_args == 0 {
36
+ false
37
+ } else {
38
+ unsafe { llvm:: LLVMRustHasAttributeAtIndex ( fnc, 0 , llvm:: AttributeKind :: StructRet ) }
39
+ }
40
+ }
41
+
42
+ // When we call the `__enzyme_autodiff` or `__enzyme_fwddiff` function, we need to pass all the
43
+ // original inputs, as well as metadata and the additional shadow arguments.
44
+ // This function matches the arguments from the outer function to the inner enzyme call.
45
+ //
46
+ // This function also considers that Rust level arguments not always match the llvm-ir level
47
+ // arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on
48
+ // llvm-ir level. The number of activities matches the number of Rust level arguments, so we
49
+ // need to match those.
50
+ // FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it
51
+ // using iterators and peek()?
31
52
fn match_args_from_caller_to_enzyme < ' ll > (
32
53
cx : & SimpleCx < ' ll > ,
54
+ width : u32 ,
33
55
args : & mut Vec < & ' ll llvm:: Value > ,
34
56
inputs : & [ DiffActivity ] ,
35
57
outer_args : & [ & ' ll llvm:: Value ] ,
58
+ has_sret : bool ,
36
59
) {
37
60
debug ! ( "matching autodiff arguments" ) ;
38
61
// We now handle the issue that Rust level arguments not always match the llvm-ir level
@@ -44,6 +67,14 @@ fn match_args_from_caller_to_enzyme<'ll>(
44
67
let mut outer_pos: usize = 0 ;
45
68
let mut activity_pos = 0 ;
46
69
70
+ if has_sret {
71
+ // Then the first outer arg is the sret pointer. Enzyme doesn't know about sret, so the
72
+ // inner function will still return something. We increase our outer_pos by one,
73
+ // and once we're done with all other args we will take the return of the inner call and
74
+ // update the sret pointer with it
75
+ outer_pos = 1 ;
76
+ }
77
+
47
78
let enzyme_const = cx. create_metadata ( "enzyme_const" . to_string ( ) ) . unwrap ( ) ;
48
79
let enzyme_out = cx. create_metadata ( "enzyme_out" . to_string ( ) ) . unwrap ( ) ;
49
80
let enzyme_dup = cx. create_metadata ( "enzyme_dup" . to_string ( ) ) . unwrap ( ) ;
@@ -92,39 +123,40 @@ fn match_args_from_caller_to_enzyme<'ll>(
92
123
// (..., metadata! enzyme_dup, ptr, ptr, int1, ...).
93
124
// FIXME(ZuseZ4): We will upstream a safety check later which asserts that
94
125
// int2 >= int1, which means the shadow vector is large enough to store the gradient.
95
- assert ! ( unsafe {
96
- llvm:: LLVMRustGetTypeKind ( next_outer_ty) == llvm:: TypeKind :: Integer
97
- } ) ;
98
- let next_outer_arg2 = outer_args[ outer_pos + 2 ] ;
99
- let next_outer_ty2 = cx. val_ty ( next_outer_arg2) ;
100
- assert ! ( unsafe {
101
- llvm:: LLVMRustGetTypeKind ( next_outer_ty2) == llvm:: TypeKind :: Pointer
102
- } ) ;
103
- let next_outer_arg3 = outer_args[ outer_pos + 3 ] ;
104
- let next_outer_ty3 = cx. val_ty ( next_outer_arg3) ;
105
- assert ! ( unsafe {
106
- llvm:: LLVMRustGetTypeKind ( next_outer_ty3) == llvm:: TypeKind :: Integer
107
- } ) ;
108
- args. push ( next_outer_arg2) ;
126
+ assert_eq ! ( cx. type_kind( next_outer_ty) , TypeKind :: Integer ) ;
127
+
128
+ for i in 0 ..( width as usize ) {
129
+ let next_outer_arg2 = outer_args[ outer_pos + 2 * ( i + 1 ) ] ;
130
+ let next_outer_ty2 = cx. val_ty ( next_outer_arg2) ;
131
+ assert_eq ! ( cx. type_kind( next_outer_ty2) , TypeKind :: Pointer ) ;
132
+ let next_outer_arg3 = outer_args[ outer_pos + 2 * ( i + 1 ) + 1 ] ;
133
+ let next_outer_ty3 = cx. val_ty ( next_outer_arg3) ;
134
+ assert_eq ! ( cx. type_kind( next_outer_ty3) , TypeKind :: Integer ) ;
135
+ args. push ( next_outer_arg2) ;
136
+ }
109
137
args. push ( cx. get_metadata_value ( enzyme_const) ) ;
110
138
args. push ( next_outer_arg) ;
111
- outer_pos += 4 ;
139
+ outer_pos += 2 + 2 * width as usize ;
112
140
activity_pos += 2 ;
113
141
} else {
114
142
// A duplicated pointer will have the following two outer_fn arguments:
115
143
// (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
116
144
// (..., metadata! enzyme_dup, ptr, ptr, ...).
117
145
if matches ! ( diff_activity, DiffActivity :: Duplicated | DiffActivity :: DuplicatedOnly )
118
146
{
119
- assert ! (
120
- unsafe { llvm:: LLVMRustGetTypeKind ( next_outer_ty) }
121
- == llvm:: TypeKind :: Pointer
122
- ) ;
147
+ assert_eq ! ( cx. type_kind( next_outer_ty) , TypeKind :: Pointer ) ;
123
148
}
124
149
// In the case of Dual we don't have assumptions, e.g. f32 would be valid.
125
150
args. push ( next_outer_arg) ;
126
151
outer_pos += 2 ;
127
152
activity_pos += 1 ;
153
+
154
+ // Now, if width > 1, we need to account for that
155
+ for _ in 1 ..width {
156
+ let next_outer_arg = outer_args[ outer_pos] ;
157
+ args. push ( next_outer_arg) ;
158
+ outer_pos += 1 ;
159
+ }
128
160
}
129
161
} else {
130
162
// We do not differentiate with resprect to this argument.
@@ -135,6 +167,76 @@ fn match_args_from_caller_to_enzyme<'ll>(
135
167
}
136
168
}
137
169
170
+ // On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
171
+ // arguments. We do however need to declare them with their correct return type.
172
+ // We already figured the correct return type out in our frontend, when generating the outer_fn,
173
+ // so we can now just go ahead and use that. This is not always trivial, e.g. because sret.
174
+ // Beyond sret, this article describes our challenges nicely:
175
+ // <https://yorickpeterse.com/articles/the-mess-that-is-handling-structure-arguments-and-returns-in-llvm/>
176
+ // I.e. (i32, f32) will get merged into i64, but we don't handle that yet.
177
+ fn compute_enzyme_fn_ty < ' ll > (
178
+ cx : & SimpleCx < ' ll > ,
179
+ attrs : & AutoDiffAttrs ,
180
+ fn_to_diff : & ' ll Value ,
181
+ outer_fn : & ' ll Value ,
182
+ ) -> & ' ll llvm:: Type {
183
+ let fn_ty = cx. get_type_of_global ( outer_fn) ;
184
+ let mut ret_ty = cx. get_return_type ( fn_ty) ;
185
+
186
+ let has_sret = has_sret ( outer_fn) ;
187
+
188
+ if has_sret {
189
+ // Now we don't just forward the return type, so we have to figure it out based on the
190
+ // primal return type, in combination with the autodiff settings.
191
+ let fn_ty = cx. get_type_of_global ( fn_to_diff) ;
192
+ let inner_ret_ty = cx. get_return_type ( fn_ty) ;
193
+
194
+ let void_ty = unsafe { llvm:: LLVMVoidTypeInContext ( cx. llcx ) } ;
195
+ if inner_ret_ty == void_ty {
196
+ // This indicates that even the inner function has an sret.
197
+ // Right now I only look for an sret in the outer function.
198
+ // This *probably* needs some extra handling, but I never ran
199
+ // into such a case. So I'll wait for user reports to have a test case.
200
+ bug ! ( "sret in inner function" ) ;
201
+ }
202
+
203
+ if attrs. width == 1 {
204
+ todo ! ( "Handle sret for scalar ad" ) ;
205
+ } else {
206
+ // First we check if we also have to deal with the primal return.
207
+ match attrs. mode {
208
+ DiffMode :: Forward => match attrs. ret_activity {
209
+ DiffActivity :: Dual => {
210
+ let arr_ty =
211
+ unsafe { llvm:: LLVMArrayType2 ( inner_ret_ty, attrs. width as u64 + 1 ) } ;
212
+ ret_ty = arr_ty;
213
+ }
214
+ DiffActivity :: DualOnly => {
215
+ let arr_ty =
216
+ unsafe { llvm:: LLVMArrayType2 ( inner_ret_ty, attrs. width as u64 ) } ;
217
+ ret_ty = arr_ty;
218
+ }
219
+ DiffActivity :: Const => {
220
+ todo ! ( "Not sure, do we need to do something here?" ) ;
221
+ }
222
+ _ => {
223
+ bug ! ( "unreachable" ) ;
224
+ }
225
+ } ,
226
+ DiffMode :: Reverse => {
227
+ todo ! ( "Handle sret for reverse mode" ) ;
228
+ }
229
+ _ => {
230
+ bug ! ( "unreachable" ) ;
231
+ }
232
+ }
233
+ }
234
+ }
235
+
236
+ // LLVM can figure out the input types on it's own, so we take a shortcut here.
237
+ unsafe { llvm:: LLVMFunctionType ( ret_ty, ptr:: null ( ) , 0 , True ) }
238
+ }
239
+
138
240
/// When differentiating `fn_to_diff`, take a `outer_fn` and generate another
139
241
/// function with expected naming and calling conventions[^1] which will be
140
242
/// discovered by the enzyme LLVM pass and its body populated with the differentiated
@@ -197,17 +299,9 @@ fn generate_enzyme_call<'ll>(
197
299
// }
198
300
// ```
199
301
unsafe {
200
- // On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
201
- // arguments. We do however need to declare them with their correct return type.
202
- // We already figured the correct return type out in our frontend, when generating the outer_fn,
203
- // so we can now just go ahead and use that. FIXME(ZuseZ4): This doesn't handle sret yet.
204
- let fn_ty = llvm:: LLVMGlobalGetValueType ( outer_fn) ;
205
- let ret_ty = llvm:: LLVMGetReturnType ( fn_ty) ;
206
-
207
- // LLVM can figure out the input types on it's own, so we take a shortcut here.
208
- let enzyme_ty = llvm:: LLVMFunctionType ( ret_ty, ptr:: null ( ) , 0 , True ) ;
302
+ let enzyme_ty = compute_enzyme_fn_ty ( cx, & attrs, fn_to_diff, outer_fn) ;
209
303
210
- //FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
304
+ // FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
211
305
// think a bit more about what should go here.
212
306
let cc = llvm:: LLVMGetFunctionCallConv ( outer_fn) ;
213
307
let ad_fn = declare_simple_fn (
@@ -240,14 +334,27 @@ fn generate_enzyme_call<'ll>(
240
334
if matches ! ( attrs. ret_activity, DiffActivity :: Dual | DiffActivity :: Active ) {
241
335
args. push ( cx. get_metadata_value ( enzyme_primal_ret) ) ;
242
336
}
337
+ if attrs. width > 1 {
338
+ let enzyme_width = cx. create_metadata ( "enzyme_width" . to_string ( ) ) . unwrap ( ) ;
339
+ args. push ( cx. get_metadata_value ( enzyme_width) ) ;
340
+ args. push ( cx. get_const_i64 ( attrs. width as u64 ) ) ;
341
+ }
243
342
343
+ let has_sret = has_sret ( outer_fn) ;
244
344
let outer_args: Vec < & llvm:: Value > = get_params ( outer_fn) ;
245
- match_args_from_caller_to_enzyme ( & cx, & mut args, & attrs. input_activity , & outer_args) ;
345
+ match_args_from_caller_to_enzyme (
346
+ & cx,
347
+ attrs. width ,
348
+ & mut args,
349
+ & attrs. input_activity ,
350
+ & outer_args,
351
+ has_sret,
352
+ ) ;
246
353
247
354
let call = builder. call ( enzyme_ty, ad_fn, & args, None ) ;
248
355
249
356
// This part is a bit iffy. LLVM requires that a call to an inlineable function has some
250
- // metadata attachted to it, but we just created this code oota. Given that the
357
+ // metadata attached to it, but we just created this code oota. Given that the
251
358
// differentiated function already has partly confusing metadata, and given that this
252
359
// affects nothing but the auttodiff IR, we take a shortcut and just steal metadata from the
253
360
// dummy code which we inserted at a higher level.
@@ -268,7 +375,22 @@ fn generate_enzyme_call<'ll>(
268
375
// Now that we copied the metadata, get rid of dummy code.
269
376
llvm:: LLVMRustEraseInstUntilInclusive ( entry, last_inst) ;
270
377
271
- if cx. val_ty ( call) == cx. type_void ( ) {
378
+ if cx. val_ty ( call) == cx. type_void ( ) || has_sret {
379
+ if has_sret {
380
+ // This is what we already have in our outer_fn (shortened):
381
+ // define void @_foo(ptr <..> sret([32 x i8]) initializes((0, 32)) %0, <...>) {
382
+ // %7 = call [4 x double] (...) @__enzyme_fwddiff_foo(ptr @square, metadata !"enzyme_width", i64 4, <...>)
383
+ // <Here we are, we want to add the following two lines>
384
+ // store [4 x double] %7, ptr %0, align 8
385
+ // ret void
386
+ // }
387
+
388
+ // now store the result of the enzyme call into the sret pointer.
389
+ let sret_ptr = outer_args[ 0 ] ;
390
+ let call_ty = cx. val_ty ( call) ;
391
+ assert_eq ! ( cx. type_kind( call_ty) , TypeKind :: Array ) ;
392
+ llvm:: LLVMBuildStore ( & builder. llbuilder , call, sret_ptr) ;
393
+ }
272
394
builder. ret_void ( ) ;
273
395
} else {
274
396
builder. ret ( call) ;
@@ -300,8 +422,7 @@ pub(crate) fn differentiate<'ll>(
300
422
if !diff_items. is_empty ( )
301
423
&& !cgcx. opts . unstable_opts . autodiff . contains ( & rustc_session:: config:: AutoDiff :: Enable )
302
424
{
303
- let dcx = cgcx. create_dcx ( ) ;
304
- return Err ( dcx. handle ( ) . emit_almost_fatal ( AutoDiffWithoutEnable ) ) ;
425
+ return Err ( diag_handler. handle ( ) . emit_almost_fatal ( AutoDiffWithoutEnable ) ) ;
305
426
}
306
427
307
428
// Before dumping the module, we want all the TypeTrees to become part of the module.
0 commit comments