@@ -3,7 +3,7 @@ module DynamicExpressionsLoopVectorizationExt
3
3
using LoopVectorization: @turbo
4
4
using DynamicExpressions: AbstractExpressionNode
5
5
using DynamicExpressions. UtilsModule: ResultOk, fill_similar
6
- using DynamicExpressions. EvaluateModule: @return_on_check
6
+ using DynamicExpressions. EvaluateModule: @return_on_nonfinite_val , EvalOptions
7
7
import DynamicExpressions. EvaluateModule:
8
8
deg1_eval,
9
9
deg2_eval,
@@ -18,7 +18,10 @@ import DynamicExpressions.ExtensionInterfaceModule:
18
18
_is_loopvectorization_loaded (:: Int ) = true
19
19
20
20
function deg2_eval (
21
- cumulator_l:: AbstractVector{T} , cumulator_r:: AbstractVector{T} , op:: F , :: Val{true}
21
+ cumulator_l:: AbstractVector{T} ,
22
+ cumulator_r:: AbstractVector{T} ,
23
+ op:: F ,
24
+ :: EvalOptions{true} ,
22
25
):: ResultOk where {T<: Number ,F}
23
26
@turbo for j in eachindex (cumulator_l)
24
27
x = op (cumulator_l[j], cumulator_r[j])
@@ -28,7 +31,7 @@ function deg2_eval(
28
31
end
29
32
30
33
function deg1_eval (
31
- cumulator:: AbstractVector{T} , op:: F , :: Val {true}
34
+ cumulator:: AbstractVector{T} , op:: F , :: EvalOptions {true}
32
35
):: ResultOk where {T<: Number ,F}
33
36
@turbo for j in eachindex (cumulator)
34
37
x = op (cumulator[j])
@@ -38,21 +41,25 @@ function deg1_eval(
38
41
end
39
42
40
43
function deg1_l2_ll0_lr0_eval (
41
- tree:: AbstractExpressionNode{T} , cX:: AbstractMatrix{T} , op:: F , op_l:: F2 , :: Val{true}
44
+ tree:: AbstractExpressionNode{T} ,
45
+ cX:: AbstractMatrix{T} ,
46
+ op:: F ,
47
+ op_l:: F2 ,
48
+ eval_options:: EvalOptions{true} ,
42
49
) where {T<: Number ,F,F2}
43
50
if tree. l. l. constant && tree. l. r. constant
44
51
val_ll = tree. l. l. val
45
52
val_lr = tree. l. r. val
46
- @return_on_check val_ll cX
47
- @return_on_check val_lr cX
53
+ @return_on_nonfinite_val (eval_options, val_ll, cX)
54
+ @return_on_nonfinite_val (eval_options, val_lr, cX)
48
55
x_l = op_l (val_ll, val_lr):: T
49
- @return_on_check x_l cX
56
+ @return_on_nonfinite_val (eval_options, x_l, cX)
50
57
x = op (x_l):: T
51
- @return_on_check x cX
58
+ @return_on_nonfinite_val (eval_options, x, cX)
52
59
return ResultOk (fill_similar (x, cX, axes (cX, 2 )), true )
53
60
elseif tree. l. l. constant
54
61
val_ll = tree. l. l. val
55
- @return_on_check val_ll cX
62
+ @return_on_nonfinite_val (eval_options, val_ll, cX)
56
63
feature_lr = tree. l. r. feature
57
64
cumulator = similar (cX, axes (cX, 2 ))
58
65
@turbo for j in axes (cX, 2 )
@@ -64,7 +71,7 @@ function deg1_l2_ll0_lr0_eval(
64
71
elseif tree. l. r. constant
65
72
feature_ll = tree. l. l. feature
66
73
val_lr = tree. l. r. val
67
- @return_on_check val_lr cX
74
+ @return_on_nonfinite_val (eval_options, val_lr, cX)
68
75
cumulator = similar (cX, axes (cX, 2 ))
69
76
@turbo for j in axes (cX, 2 )
70
77
x_l = op_l (cX[feature_ll, j], val_lr)
@@ -86,15 +93,19 @@ function deg1_l2_ll0_lr0_eval(
86
93
end
87
94
88
95
function deg1_l1_ll0_eval (
89
- tree:: AbstractExpressionNode{T} , cX:: AbstractMatrix{T} , op:: F , op_l:: F2 , :: Val{true}
96
+ tree:: AbstractExpressionNode{T} ,
97
+ cX:: AbstractMatrix{T} ,
98
+ op:: F ,
99
+ op_l:: F2 ,
100
+ eval_options:: EvalOptions{true} ,
90
101
) where {T<: Number ,F,F2}
91
102
if tree. l. l. constant
92
103
val_ll = tree. l. l. val
93
- @return_on_check val_ll cX
104
+ @return_on_nonfinite_val (eval_options, val_ll, cX)
94
105
x_l = op_l (val_ll):: T
95
- @return_on_check x_l cX
106
+ @return_on_nonfinite_val (eval_options, x_l, cX)
96
107
x = op (x_l):: T
97
- @return_on_check x cX
108
+ @return_on_nonfinite_val (eval_options, x, cX)
98
109
return ResultOk (fill_similar (x, cX, axes (cX, 2 )), true )
99
110
else
100
111
feature_ll = tree. l. l. feature
@@ -109,20 +120,23 @@ function deg1_l1_ll0_eval(
109
120
end
110
121
111
122
function deg2_l0_r0_eval (
112
- tree:: AbstractExpressionNode{T} , cX:: AbstractMatrix{T} , op:: F , :: Val{true}
123
+ tree:: AbstractExpressionNode{T} ,
124
+ cX:: AbstractMatrix{T} ,
125
+ op:: F ,
126
+ eval_options:: EvalOptions{true} ,
113
127
) where {T<: Number ,F}
114
128
if tree. l. constant && tree. r. constant
115
129
val_l = tree. l. val
116
- @return_on_check val_l cX
130
+ @return_on_nonfinite_val (eval_options, val_l, cX)
117
131
val_r = tree. r. val
118
- @return_on_check val_r cX
132
+ @return_on_nonfinite_val (eval_options, val_r, cX)
119
133
x = op (val_l, val_r):: T
120
- @return_on_check x cX
134
+ @return_on_nonfinite_val (eval_options, x, cX)
121
135
return ResultOk (fill_similar (x, cX, axes (cX, 2 )), true )
122
136
elseif tree. l. constant
123
137
cumulator = similar (cX, axes (cX, 2 ))
124
138
val_l = tree. l. val
125
- @return_on_check val_l cX
139
+ @return_on_nonfinite_val (eval_options, val_l, cX)
126
140
feature_r = tree. r. feature
127
141
@turbo for j in axes (cX, 2 )
128
142
x = op (val_l, cX[feature_r, j])
@@ -133,7 +147,7 @@ function deg2_l0_r0_eval(
133
147
cumulator = similar (cX, axes (cX, 2 ))
134
148
feature_l = tree. l. feature
135
149
val_r = tree. r. val
136
- @return_on_check val_r cX
150
+ @return_on_nonfinite_val (eval_options, val_r, cX)
137
151
@turbo for j in axes (cX, 2 )
138
152
x = op (cX[feature_l, j], val_r)
139
153
cumulator[j] = x
@@ -157,11 +171,11 @@ function deg2_l0_eval(
157
171
cumulator:: AbstractVector{T} ,
158
172
cX:: AbstractArray{T} ,
159
173
op:: F ,
160
- :: Val {true} ,
174
+ eval_options :: EvalOptions {true} ,
161
175
) where {T<: Number ,F}
162
176
if tree. l. constant
163
177
val = tree. l. val
164
- @return_on_check val cX
178
+ @return_on_nonfinite_val (eval_options, val, cX)
165
179
@turbo for j in eachindex (cumulator)
166
180
x = op (val, cumulator[j])
167
181
cumulator[j] = x
@@ -182,11 +196,11 @@ function deg2_r0_eval(
182
196
cumulator:: AbstractVector{T} ,
183
197
cX:: AbstractArray{T} ,
184
198
op:: F ,
185
- :: Val {true} ,
199
+ eval_options :: EvalOptions {true} ,
186
200
) where {T<: Number ,F}
187
201
if tree. r. constant
188
202
val = tree. r. val
189
- @return_on_check val cX
203
+ @return_on_nonfinite_val (eval_options, val, cX)
190
204
@turbo for j in eachindex (cumulator)
191
205
x = op (cumulator[j], val)
192
206
cumulator[j] = x
@@ -203,11 +217,15 @@ function deg2_r0_eval(
203
217
end
204
218
205
219
# # Interface with Bumper.jl
206
- function bumper_kern1! (op:: F , cumulator, :: Val{true} ) where {F}
220
+ function bumper_kern1! (
221
+ op:: F , cumulator, :: EvalOptions{true,true,early_exit}
222
+ ) where {F,early_exit}
207
223
@turbo @. cumulator = op (cumulator)
208
224
return cumulator
209
225
end
210
- function bumper_kern2! (op:: F , cumulator1, cumulator2, :: Val{true} ) where {F}
226
+ function bumper_kern2! (
227
+ op:: F , cumulator1, cumulator2, :: EvalOptions{true,true,early_exit}
228
+ ) where {F,early_exit}
211
229
@turbo @. cumulator1 = op (cumulator1, cumulator2)
212
230
return cumulator1
213
231
end
0 commit comments