7
7
use std:: autodiff:: autodiff;
8
8
9
9
#[ autodiff( d_square3, Forward , Dual , DualOnly ) ]
10
- #[ no_mangle]
11
- fn squaref ( x : & f32 ) -> f32 {
12
- 2.0 * x * x
13
- }
14
-
15
10
#[ autodiff( d_square2, Forward , 4 , Dual , DualOnly ) ]
16
- #[ autodiff( d_square , Forward , 4 , Dual , Dual ) ]
11
+ #[ autodiff( d_square1 , Forward , 4 , Dual , Dual ) ]
17
12
#[ no_mangle]
18
13
// Primal function under test: returns the square of the referenced value (`x * x`).
// The `#[autodiff(...)]` attributes directly above this item name the derivative
// functions (`d_square2`, `d_square1`) that are called from `main`.
fn square ( x : & f32 ) -> f32 {
19
14
x * x
20
15
}
21
16
22
- // CHECK:define internal fastcc void @diffe4square([4 x ptr] %"x'"
23
- // CHECK-NEXT:invertstart:
24
- // CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0
25
- // CHECK-NEXT: %1 = load double, ptr %0, align 8, !alias.scope !15950, !noalias !15953
26
- // CHECK-NEXT: %2 = fadd fast double %1, 6.000000e+00
27
- // CHECK-NEXT: store double %2, ptr %0, align 8, !alias.scope !15950, !noalias !15953
28
- // CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 1
29
- // CHECK-NEXT: %4 = load double, ptr %3, align 8, !alias.scope !15958, !noalias !15959
30
- // CHECK-NEXT: %5 = fadd fast double %4, 6.000000e+00
31
- // CHECK-NEXT: store double %5, ptr %3, align 8, !alias.scope !15958, !noalias !15959
32
- // CHECK-NEXT: %6 = extractvalue [4 x ptr] %"x'", 2
33
- // CHECK-NEXT: %7 = load double, ptr %6, align 8, !alias.scope !15960, !noalias !15961
34
- // CHECK-NEXT: %8 = fadd fast double %7, 6.000000e+00
35
- // CHECK-NEXT: store double %8, ptr %6, align 8, !alias.scope !15960, !noalias !15961
36
- // CHECK-NEXT: %9 = extractvalue [4 x ptr] %"x'", 3
37
- // CHECK-NEXT: %10 = load double, ptr %9, align 8, !alias.scope !15962, !noalias !15963
38
- // CHECK-NEXT: %11 = fadd fast double %10, 6.000000e+00
39
- // CHECK-NEXT: store double %11, ptr %9, align 8, !alias.scope !15962, !noalias !15963
40
- // CHECK-NEXT: ret void
41
- // CHECK-NEXT:}
17
+ // d_square2
18
+ // CHECK: define internal fastcc [4 x float] @fwddiffe4square(float %x.0.val, [4 x ptr] %"x'")
19
+ // CHECK-NEXT: start:
20
+ // CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0
21
+ // CHECK-NEXT: %"_2'ipl" = load float, ptr %0, align 4, !alias.scope !38, !noalias !39
22
+ // CHECK-NEXT: %1 = extractvalue [4 x ptr] %"x'", 1
23
+ // CHECK-NEXT: %"_2'ipl1" = load float, ptr %1, align 4, !alias.scope !40, !noalias !41
24
+ // CHECK-NEXT: %2 = extractvalue [4 x ptr] %"x'", 2
25
+ // CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4, !alias.scope !42, !noalias !43
26
+ // CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3
27
+ // CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4, !alias.scope !44, !noalias !45
28
+ // CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0
29
+ // CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1
30
+ // CHECK-NEXT: %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2
31
+ // CHECK-NEXT: %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3
32
+ // CHECK-NEXT: %8 = fadd fast <4 x float> %7, %7
33
+ // CHECK-NEXT: %9 = insertelement <4 x float> poison, float %x.0.val, i64 0
34
+ // CHECK-NEXT: %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer
35
+ // CHECK-NEXT: %11 = fmul fast <4 x float> %8, %10
36
+ // CHECK-NEXT: %12 = extractelement <4 x float> %11, i64 0
37
+ // CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0
38
+ // CHECK-NEXT: %14 = extractelement <4 x float> %11, i64 1
39
+ // CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1
40
+ // CHECK-NEXT: %16 = extractelement <4 x float> %11, i64 2
41
+ // CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2
42
+ // CHECK-NEXT: %18 = extractelement <4 x float> %11, i64 3
43
+ // CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3
44
+ // CHECK-NEXT: ret [4 x float] %19
45
+ // CHECK-NEXT: }
46
+
47
+ // d_square1, the extra float is the original return value (x * x)
48
+ // CHECK: define internal fastcc { float, [4 x float] } @fwddiffe4square.1(float %x.0.val, [4 x ptr] %"x'")
49
+ // CHECK-NEXT: start:
50
+ // CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0
51
+ // CHECK-NEXT: %"_2'ipl" = load float, ptr %0, align 4, !alias.scope !46, !noalias !47
52
+ // CHECK-NEXT: %1 = extractvalue [4 x ptr] %"x'", 1
53
+ // CHECK-NEXT: %"_2'ipl1" = load float, ptr %1, align 4, !alias.scope !48, !noalias !49
54
+ // CHECK-NEXT: %2 = extractvalue [4 x ptr] %"x'", 2
55
+ // CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4, !alias.scope !50, !noalias !51
56
+ // CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3
57
+ // CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4, !alias.scope !52, !noalias !53
58
+ // CHECK-NEXT: %_0 = fmul float %x.0.val, %x.0.val
59
+ // CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0
60
+ // CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1
61
+ // CHECK-NEXT: %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2
62
+ // CHECK-NEXT: %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3
63
+ // CHECK-NEXT: %8 = fadd fast <4 x float> %7, %7
64
+ // CHECK-NEXT: %9 = insertelement <4 x float> poison, float %x.0.val, i64 0
65
+ // CHECK-NEXT: %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer
66
+ // CHECK-NEXT: %11 = fmul fast <4 x float> %8, %10
67
+ // CHECK-NEXT: %12 = extractelement <4 x float> %11, i64 0
68
+ // CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0
69
+ // CHECK-NEXT: %14 = extractelement <4 x float> %11, i64 1
70
+ // CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1
71
+ // CHECK-NEXT: %16 = extractelement <4 x float> %11, i64 2
72
+ // CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2
73
+ // CHECK-NEXT: %18 = extractelement <4 x float> %11, i64 3
74
+ // CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3
75
+ // CHECK-NEXT: %20 = insertvalue { float, [4 x float] } undef, float %_0, 0
76
+ // CHECK-NEXT: %21 = insertvalue { float, [4 x float] } %20, [4 x float] %19, 1
77
+ // CHECK-NEXT: ret { float, [4 x float] } %21
78
+ // CHECK-NEXT: }
42
79
43
80
fn main ( ) {
44
81
let x = std:: hint:: black_box ( 3.0 ) ;
45
82
let output = square ( & x) ;
46
83
dbg ! ( & output) ;
47
84
assert_eq ! ( 9.0 , output) ;
48
- dbg ! ( squaref ( & x) ) ;
85
+ dbg ! ( square ( & x) ) ;
49
86
50
87
let mut df_dx1 = 1.0 ;
51
88
let mut df_dx2 = 2.0 ;
@@ -54,7 +91,7 @@ fn main() {
54
91
let [ o1, o2, o3, o4] = d_square2 ( & x, & mut df_dx1, & mut df_dx2, & mut df_dx3, & mut df_dx4) ;
55
92
dbg ! ( o1, o2, o3, o4) ;
56
93
let [ output2, o1, o2, o3, o4] =
57
- d_square ( & x, & mut df_dx1, & mut df_dx2, & mut df_dx3, & mut df_dx4) ;
94
+ d_square1 ( & x, & mut df_dx1, & mut df_dx2, & mut df_dx3, & mut df_dx4) ;
58
95
dbg ! ( o1, o2, o3, o4) ;
59
96
assert_eq ! ( output, output2) ;
60
97
assert ! ( ( 6.0 - o1) . abs( ) < 1e-10 ) ;
0 commit comments