Skip to content

Commit 963b3ce

Browse files
committed
update batching test
1 parent c733a11 commit 963b3ce

File tree

1 file changed

+65
-28
lines changed

1 file changed

+65
-28
lines changed

tests/codegen/autodiffv.rs

+65-28
Original file line numberDiff line numberDiff line change
@@ -7,45 +7,82 @@
77
use std::autodiff::autodiff;
88

99
#[autodiff(d_square3, Forward, Dual, DualOnly)]
10-
#[no_mangle]
11-
fn squaref(x: &f32) -> f32 {
12-
2.0 * x * x
13-
}
14-
1510
#[autodiff(d_square2, Forward, 4, Dual, DualOnly)]
16-
#[autodiff(d_square, Forward, 4, Dual, Dual)]
11+
#[autodiff(d_square1, Forward, 4, Dual, Dual)]
1712
#[no_mangle]
1813
fn square(x: &f32) -> f32 {
1914
x * x
2015
}
2116

22-
// CHECK:define internal fastcc void @diffe4square([4 x ptr] %"x'"
23-
// CHECK-NEXT:invertstart:
24-
// CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0
25-
// CHECK-NEXT: %1 = load double, ptr %0, align 8, !alias.scope !15950, !noalias !15953
26-
// CHECK-NEXT: %2 = fadd fast double %1, 6.000000e+00
27-
// CHECK-NEXT: store double %2, ptr %0, align 8, !alias.scope !15950, !noalias !15953
28-
// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 1
29-
// CHECK-NEXT: %4 = load double, ptr %3, align 8, !alias.scope !15958, !noalias !15959
30-
// CHECK-NEXT: %5 = fadd fast double %4, 6.000000e+00
31-
// CHECK-NEXT: store double %5, ptr %3, align 8, !alias.scope !15958, !noalias !15959
32-
// CHECK-NEXT: %6 = extractvalue [4 x ptr] %"x'", 2
33-
// CHECK-NEXT: %7 = load double, ptr %6, align 8, !alias.scope !15960, !noalias !15961
34-
// CHECK-NEXT: %8 = fadd fast double %7, 6.000000e+00
35-
// CHECK-NEXT: store double %8, ptr %6, align 8, !alias.scope !15960, !noalias !15961
36-
// CHECK-NEXT: %9 = extractvalue [4 x ptr] %"x'", 3
37-
// CHECK-NEXT: %10 = load double, ptr %9, align 8, !alias.scope !15962, !noalias !15963
38-
// CHECK-NEXT: %11 = fadd fast double %10, 6.000000e+00
39-
// CHECK-NEXT: store double %11, ptr %9, align 8, !alias.scope !15962, !noalias !15963
40-
// CHECK-NEXT: ret void
41-
// CHECK-NEXT:}
17+
// d_square2
18+
// CHECK: define internal fastcc [4 x float] @fwddiffe4square(float %x.0.val, [4 x ptr] %"x'")
19+
// CHECK-NEXT: start:
20+
// CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0
21+
// CHECK-NEXT: %"_2'ipl" = load float, ptr %0, align 4, !alias.scope !38, !noalias !39
22+
// CHECK-NEXT: %1 = extractvalue [4 x ptr] %"x'", 1
23+
// CHECK-NEXT: %"_2'ipl1" = load float, ptr %1, align 4, !alias.scope !40, !noalias !41
24+
// CHECK-NEXT: %2 = extractvalue [4 x ptr] %"x'", 2
25+
// CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4, !alias.scope !42, !noalias !43
26+
// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3
27+
// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4, !alias.scope !44, !noalias !45
28+
// CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0
29+
// CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1
30+
// CHECK-NEXT: %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2
31+
// CHECK-NEXT: %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3
32+
// CHECK-NEXT: %8 = fadd fast <4 x float> %7, %7
33+
// CHECK-NEXT: %9 = insertelement <4 x float> poison, float %x.0.val, i64 0
34+
// CHECK-NEXT: %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer
35+
// CHECK-NEXT: %11 = fmul fast <4 x float> %8, %10
36+
// CHECK-NEXT: %12 = extractelement <4 x float> %11, i64 0
37+
// CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0
38+
// CHECK-NEXT: %14 = extractelement <4 x float> %11, i64 1
39+
// CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1
40+
// CHECK-NEXT: %16 = extractelement <4 x float> %11, i64 2
41+
// CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2
42+
// CHECK-NEXT: %18 = extractelement <4 x float> %11, i64 3
43+
// CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3
44+
// CHECK-NEXT: ret [4 x float] %19
45+
// CHECK-NEXT: }
46+
47+
// d_square1, the extra float is the original return value (x * x)
48+
// CHECK: define internal fastcc { float, [4 x float] } @fwddiffe4square.1(float %x.0.val, [4 x ptr] %"x'")
49+
// CHECK-NEXT: start:
50+
// CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0
51+
// CHECK-NEXT: %"_2'ipl" = load float, ptr %0, align 4, !alias.scope !46, !noalias !47
52+
// CHECK-NEXT: %1 = extractvalue [4 x ptr] %"x'", 1
53+
// CHECK-NEXT: %"_2'ipl1" = load float, ptr %1, align 4, !alias.scope !48, !noalias !49
54+
// CHECK-NEXT: %2 = extractvalue [4 x ptr] %"x'", 2
55+
// CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4, !alias.scope !50, !noalias !51
56+
// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3
57+
// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4, !alias.scope !52, !noalias !53
58+
// CHECK-NEXT: %_0 = fmul float %x.0.val, %x.0.val
59+
// CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0
60+
// CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1
61+
// CHECK-NEXT: %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2
62+
// CHECK-NEXT: %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3
63+
// CHECK-NEXT: %8 = fadd fast <4 x float> %7, %7
64+
// CHECK-NEXT: %9 = insertelement <4 x float> poison, float %x.0.val, i64 0
65+
// CHECK-NEXT: %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer
66+
// CHECK-NEXT: %11 = fmul fast <4 x float> %8, %10
67+
// CHECK-NEXT: %12 = extractelement <4 x float> %11, i64 0
68+
// CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0
69+
// CHECK-NEXT: %14 = extractelement <4 x float> %11, i64 1
70+
// CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1
71+
// CHECK-NEXT: %16 = extractelement <4 x float> %11, i64 2
72+
// CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2
73+
// CHECK-NEXT: %18 = extractelement <4 x float> %11, i64 3
74+
// CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3
75+
// CHECK-NEXT: %20 = insertvalue { float, [4 x float] } undef, float %_0, 0
76+
// CHECK-NEXT: %21 = insertvalue { float, [4 x float] } %20, [4 x float] %19, 1
77+
// CHECK-NEXT: ret { float, [4 x float] } %21
78+
// CHECK-NEXT: }
4279

4380
fn main() {
4481
let x = std::hint::black_box(3.0);
4582
let output = square(&x);
4683
dbg!(&output);
4784
assert_eq!(9.0, output);
48-
dbg!(squaref(&x));
85+
dbg!(square(&x));
4986

5087
let mut df_dx1 = 1.0;
5188
let mut df_dx2 = 2.0;
@@ -54,7 +91,7 @@ fn main() {
5491
let [o1, o2, o3, o4] = d_square2(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4);
5592
dbg!(o1, o2, o3, o4);
5693
let [output2, o1, o2, o3, o4] =
57-
d_square(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4);
94+
d_square1(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4);
5895
dbg!(o1, o2, o3, o4);
5996
assert_eq!(output, output2);
6097
assert!((6.0 - o1).abs() < 1e-10);

0 commit comments

Comments
 (0)