Skip to content

Commit 17a4a24

Browse files
committed
fix: incorporate @return_on_nonfinite_val in LoopVectorization extension
1 parent 09b7a3d commit 17a4a24

File tree

1 file changed

+25
-22
lines changed

1 file changed

+25
-22
lines changed

ext/DynamicExpressionsLoopVectorizationExt.jl

+25-22
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module DynamicExpressionsLoopVectorizationExt
33
using LoopVectorization: @turbo
44
using DynamicExpressions: AbstractExpressionNode
55
using DynamicExpressions.UtilsModule: ResultOk, fill_similar
6-
using DynamicExpressions.EvaluateModule: @return_on_check, EvalOptions
6+
using DynamicExpressions.EvaluateModule: @return_on_nonfinite_val, EvalOptions
77
import DynamicExpressions.EvaluateModule:
88
deg1_eval,
99
deg2_eval,
@@ -45,21 +45,21 @@ function deg1_l2_ll0_lr0_eval(
4545
cX::AbstractMatrix{T},
4646
op::F,
4747
op_l::F2,
48-
::EvalOptions{true},
48+
eval_options::EvalOptions{true},
4949
) where {T<:Number,F,F2}
5050
if tree.l.l.constant && tree.l.r.constant
5151
val_ll = tree.l.l.val
5252
val_lr = tree.l.r.val
53-
@return_on_check val_ll cX
54-
@return_on_check val_lr cX
53+
@return_on_nonfinite_val(eval_options, val_ll, cX)
54+
@return_on_nonfinite_val(eval_options, val_lr, cX)
5555
x_l = op_l(val_ll, val_lr)::T
56-
@return_on_check x_l cX
56+
@return_on_nonfinite_val(eval_options, x_l, cX)
5757
x = op(x_l)::T
58-
@return_on_check x cX
58+
@return_on_nonfinite_val(eval_options, x, cX)
5959
return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
6060
elseif tree.l.l.constant
6161
val_ll = tree.l.l.val
62-
@return_on_check val_ll cX
62+
@return_on_nonfinite_val(eval_options, val_ll, cX)
6363
feature_lr = tree.l.r.feature
6464
cumulator = similar(cX, axes(cX, 2))
6565
@turbo for j in axes(cX, 2)
@@ -71,7 +71,7 @@ function deg1_l2_ll0_lr0_eval(
7171
elseif tree.l.r.constant
7272
feature_ll = tree.l.l.feature
7373
val_lr = tree.l.r.val
74-
@return_on_check val_lr cX
74+
@return_on_nonfinite_val(eval_options, val_lr, cX)
7575
cumulator = similar(cX, axes(cX, 2))
7676
@turbo for j in axes(cX, 2)
7777
x_l = op_l(cX[feature_ll, j], val_lr)
@@ -97,15 +97,15 @@ function deg1_l1_ll0_eval(
9797
cX::AbstractMatrix{T},
9898
op::F,
9999
op_l::F2,
100-
::EvalOptions{true},
100+
eval_options::EvalOptions{true},
101101
) where {T<:Number,F,F2}
102102
if tree.l.l.constant
103103
val_ll = tree.l.l.val
104-
@return_on_check val_ll cX
104+
@return_on_nonfinite_val(eval_options, val_ll, cX)
105105
x_l = op_l(val_ll)::T
106-
@return_on_check x_l cX
106+
@return_on_nonfinite_val(eval_options, x_l, cX)
107107
x = op(x_l)::T
108-
@return_on_check x cX
108+
@return_on_nonfinite_val(eval_options, x, cX)
109109
return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
110110
else
111111
feature_ll = tree.l.l.feature
@@ -120,20 +120,23 @@ function deg1_l1_ll0_eval(
120120
end
121121

122122
function deg2_l0_r0_eval(
123-
tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, ::EvalOptions{true}
123+
tree::AbstractExpressionNode{T},
124+
cX::AbstractMatrix{T},
125+
op::F,
126+
eval_options::EvalOptions{true},
124127
) where {T<:Number,F}
125128
if tree.l.constant && tree.r.constant
126129
val_l = tree.l.val
127-
@return_on_check val_l cX
130+
@return_on_nonfinite_val(eval_options, val_l, cX)
128131
val_r = tree.r.val
129-
@return_on_check val_r cX
132+
@return_on_nonfinite_val(eval_options, val_r, cX)
130133
x = op(val_l, val_r)::T
131-
@return_on_check x cX
134+
@return_on_nonfinite_val(eval_options, x, cX)
132135
return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
133136
elseif tree.l.constant
134137
cumulator = similar(cX, axes(cX, 2))
135138
val_l = tree.l.val
136-
@return_on_check val_l cX
139+
@return_on_nonfinite_val(eval_options, val_l, cX)
137140
feature_r = tree.r.feature
138141
@turbo for j in axes(cX, 2)
139142
x = op(val_l, cX[feature_r, j])
@@ -144,7 +147,7 @@ function deg2_l0_r0_eval(
144147
cumulator = similar(cX, axes(cX, 2))
145148
feature_l = tree.l.feature
146149
val_r = tree.r.val
147-
@return_on_check val_r cX
150+
@return_on_nonfinite_val(eval_options, val_r, cX)
148151
@turbo for j in axes(cX, 2)
149152
x = op(cX[feature_l, j], val_r)
150153
cumulator[j] = x
@@ -168,11 +171,11 @@ function deg2_l0_eval(
168171
cumulator::AbstractVector{T},
169172
cX::AbstractArray{T},
170173
op::F,
171-
::EvalOptions{true},
174+
eval_options::EvalOptions{true},
172175
) where {T<:Number,F}
173176
if tree.l.constant
174177
val = tree.l.val
175-
@return_on_check val cX
178+
@return_on_nonfinite_val(eval_options, val, cX)
176179
@turbo for j in eachindex(cumulator)
177180
x = op(val, cumulator[j])
178181
cumulator[j] = x
@@ -193,11 +196,11 @@ function deg2_r0_eval(
193196
cumulator::AbstractVector{T},
194197
cX::AbstractArray{T},
195198
op::F,
196-
::EvalOptions{true},
199+
eval_options::EvalOptions{true},
197200
) where {T<:Number,F}
198201
if tree.r.constant
199202
val = tree.r.val
200-
@return_on_check val cX
203+
@return_on_nonfinite_val(eval_options, val, cX)
201204
@turbo for j in eachindex(cumulator)
202205
x = op(cumulator[j], val)
203206
cumulator[j] = x

0 commit comments

Comments
 (0)