Skip to content

Commit db39efc

Browse files
authored
Merge pull request #91 from nmheim/nh/early-exit
Add parameter to disable early exit of expression evaluation
2 parents f650e28 + 17a4a24 commit db39efc

14 files changed

+415
-199
lines changed

benchmark/benchmarks.jl

+8-1
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,18 @@ function benchmark_evaluation()
7474
extra_kws...
7575
)
7676
suite[T]["evaluation$(extra_key)"] = @benchmarkable(
77-
[eval_tree_array(tree, X, $operators; turbo=$turbo, $extra_kws...) for tree in trees],
77+
[eval_tree_array(tree, X, $operators; kws...) for tree in trees],
7878
setup=(
7979
X=randn(MersenneTwister(0), $T, 5, $n);
8080
treesize=20;
8181
ntrees=100;
82+
kws=$(
83+
if @isdefined(EvalOptions)
84+
(; eval_options=EvalOptions(; turbo=turbo, extra_kws...))
85+
else
86+
(; turbo, extra_kws...)
87+
end
88+
);
8289
trees=[gen_random_tree_fixed_size(treesize, $operators, 5, $T) for _ in 1:ntrees]
8390
)
8491
)

docs/src/eval.md

+17-6
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,18 @@ Given an expression tree specified with a `Node` type, you may evaluate the expr
66
over an array of data with the following command:
77

88
```@docs
9-
eval_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum) where {T<:Number}
9+
eval_tree_array(
10+
tree::AbstractExpressionNode{T},
11+
cX::AbstractMatrix{T},
12+
operators::OperatorEnum;
13+
eval_options::Union{EvalOptions,Nothing}=nothing,
14+
) where {T}
1015
```
1116

12-
Assuming you are only using a single `OperatorEnum`, you can also use
13-
the following shorthand by using the expression as a function:
17+
You can also use the following shorthand by using the expression as a function:
1418

1519
```
16-
(tree::AbstractExpressionNode)(X::AbstractMatrix{T}, operators::OperatorEnum; turbo::Union{Bool,Val}=false, bumper::Union{Bool,Val}=Val(false))
20+
(tree::AbstractExpressionNode)(X, operators::OperatorEnum; kws...)
1721
1822
Evaluate a binary tree (equation) over a given input data matrix. The
1923
operators contain all of the operators used. This function fuses doublets
@@ -23,8 +27,7 @@ and triplets of operations for lower memory usage.
2327
- `tree::AbstractExpressionNode`: The root node of the tree to evaluate.
2428
- `cX::AbstractMatrix{T}`: The input data to evaluate the tree on.
2529
- `operators::OperatorEnum`: The operators used in the tree.
26-
- `turbo::Union{Bool,Val}`: Use LoopVectorization.jl for faster evaluation.
27-
- `bumper::Union{Bool,Val}`: Use Bumper.jl for faster evaluation.
30+
- `kws...`: Passed to [`eval_tree_array`](@ref).
2831
2932
# Returns
3033
- `output::AbstractVector{T}`: the result, which is a 1D array.
@@ -53,6 +56,14 @@ It also re-defines `print`, `show`, and the various operators, to work with the
5356
Thus, if you define an expression with one `OperatorEnum`, and then try to
5457
evaluate it or print it with a different `OperatorEnum`, you will get undefined behavior!
5558

59+
For safer behavior, you should use [`Expression`](@ref) objects.
60+
61+
Evaluation options are specified using `EvalOptions`:
62+
63+
```@docs
64+
EvalOptions
65+
```
66+
5667
You can also work with arbitrary types, by defining a `GenericOperatorEnum` instead.
5768
The notation is the same for `eval_tree_array`, though it will return `nothing`
5869
when it can't find a method, and not do any NaN checks:

ext/DynamicExpressionsBumperExt.jl

+25-19
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module DynamicExpressionsBumperExt
22

33
using Bumper: @no_escape, @alloc
44
using DynamicExpressions:
5-
OperatorEnum, AbstractExpressionNode, tree_mapreduce, is_valid_array
5+
OperatorEnum, AbstractExpressionNode, tree_mapreduce, is_valid_array, EvalOptions
66
using DynamicExpressions.UtilsModule: ResultOk, counttuple
77

88
import DynamicExpressions.ExtensionInterfaceModule:
@@ -12,8 +12,8 @@ function bumper_eval_tree_array(
1212
tree::AbstractExpressionNode{T},
1313
cX::AbstractMatrix{T},
1414
operators::OperatorEnum,
15-
::Val{turbo},
16-
) where {T,turbo}
15+
eval_options::EvalOptions{turbo,true,early_exit},
16+
) where {T,turbo,early_exit}
1717
result = similar(cX, axes(cX, 2))
1818
n = size(cX, 2)
1919
all_ok = Ref(false)
@@ -26,7 +26,7 @@ function bumper_eval_tree_array(
2626
ok = if leaf_node.constant
2727
v = leaf_node.val
2828
ar .= v
29-
isfinite(v)
29+
early_exit ? isfinite(v) : true
3030
else
3131
ar .= view(cX, leaf_node.feature, :)
3232
true
@@ -38,7 +38,7 @@ function bumper_eval_tree_array(
3838
# In the evaluation kernel, we combine the branch nodes
3939
# with the arrays created by the leaf nodes:
4040
((args::Vararg{Any,M}) where {M}) ->
41-
dispatch_kerns!(operators, args..., Val(turbo)),
41+
dispatch_kerns!(operators, args..., eval_options),
4242
tree;
4343
break_sharing=Val(true),
4444
)
@@ -49,55 +49,61 @@ function bumper_eval_tree_array(
4949
return (result, all_ok[])
5050
end
5151

52-
function dispatch_kerns!(operators, branch_node, cumulator, ::Val{turbo}) where {turbo}
52+
function dispatch_kerns!(
53+
operators, branch_node, cumulator, eval_options::EvalOptions{<:Any,true,early_exit}
54+
) where {early_exit}
5355
cumulator.ok || return cumulator
5456

55-
out = dispatch_kern1!(operators.unaops, branch_node.op, cumulator.x, Val(turbo))
56-
return ResultOk(out, is_valid_array(out))
57+
out = dispatch_kern1!(operators.unaops, branch_node.op, cumulator.x, eval_options)
58+
return ResultOk(out, early_exit ? is_valid_array(out) : true)
5759
end
5860
function dispatch_kerns!(
59-
operators, branch_node, cumulator1, cumulator2, ::Val{turbo}
60-
) where {turbo}
61+
operators,
62+
branch_node,
63+
cumulator1,
64+
cumulator2,
65+
eval_options::EvalOptions{<:Any,true,early_exit},
66+
) where {early_exit}
6167
cumulator1.ok || return cumulator1
6268
cumulator2.ok || return cumulator2
6369

6470
out = dispatch_kern2!(
65-
operators.binops, branch_node.op, cumulator1.x, cumulator2.x, Val(turbo)
71+
operators.binops, branch_node.op, cumulator1.x, cumulator2.x, eval_options
6672
)
67-
return ResultOk(out, is_valid_array(out))
73+
return ResultOk(out, early_exit ? is_valid_array(out) : true)
6874
end
6975

70-
@generated function dispatch_kern1!(unaops, op_idx, cumulator, ::Val{turbo}) where {turbo}
76+
@generated function dispatch_kern1!(unaops, op_idx, cumulator, eval_options::EvalOptions)
7177
nuna = counttuple(unaops)
7278
quote
7379
Base.@nif(
7480
$nuna,
7581
i -> i == op_idx,
7682
i -> let op = unaops[i]
77-
return bumper_kern1!(op, cumulator, Val(turbo))
83+
return bumper_kern1!(op, cumulator, eval_options)
7884
end,
7985
)
8086
end
8187
end
8288
@generated function dispatch_kern2!(
83-
binops, op_idx, cumulator1, cumulator2, ::Val{turbo}
84-
) where {turbo}
89+
binops, op_idx, cumulator1, cumulator2, eval_options::EvalOptions
90+
)
8591
nbin = counttuple(binops)
8692
quote
8793
Base.@nif(
8894
$nbin,
8995
i -> i == op_idx,
9096
i -> let op = binops[i]
91-
return bumper_kern2!(op, cumulator1, cumulator2, Val(turbo))
97+
return bumper_kern2!(op, cumulator1, cumulator2, eval_options)
9298
end,
9399
)
94100
end
95101
end
96-
function bumper_kern1!(op::F, cumulator, ::Val{false}) where {F}
102+
function bumper_kern1!(op::F, cumulator, ::EvalOptions{false,true}) where {F}
97103
@. cumulator = op(cumulator)
98104
return cumulator
99105
end
100-
function bumper_kern2!(op::F, cumulator1, cumulator2, ::Val{false}) where {F}
106+
function bumper_kern2!(op::F, cumulator1, cumulator2, ::EvalOptions{false,true}) where {F}
101107
@. cumulator1 = op(cumulator1, cumulator2)
102108
return cumulator1
103109
end

ext/DynamicExpressionsLoopVectorizationExt.jl

+44-26
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module DynamicExpressionsLoopVectorizationExt
33
using LoopVectorization: @turbo
44
using DynamicExpressions: AbstractExpressionNode
55
using DynamicExpressions.UtilsModule: ResultOk, fill_similar
6-
using DynamicExpressions.EvaluateModule: @return_on_check
6+
using DynamicExpressions.EvaluateModule: @return_on_nonfinite_val, EvalOptions
77
import DynamicExpressions.EvaluateModule:
88
deg1_eval,
99
deg2_eval,
@@ -18,7 +18,10 @@ import DynamicExpressions.ExtensionInterfaceModule:
1818
_is_loopvectorization_loaded(::Int) = true
1919

2020
function deg2_eval(
21-
cumulator_l::AbstractVector{T}, cumulator_r::AbstractVector{T}, op::F, ::Val{true}
21+
cumulator_l::AbstractVector{T},
22+
cumulator_r::AbstractVector{T},
23+
op::F,
24+
::EvalOptions{true},
2225
)::ResultOk where {T<:Number,F}
2326
@turbo for j in eachindex(cumulator_l)
2427
x = op(cumulator_l[j], cumulator_r[j])
@@ -28,7 +31,7 @@ function deg2_eval(
2831
end
2932

3033
function deg1_eval(
31-
cumulator::AbstractVector{T}, op::F, ::Val{true}
34+
cumulator::AbstractVector{T}, op::F, ::EvalOptions{true}
3235
)::ResultOk where {T<:Number,F}
3336
@turbo for j in eachindex(cumulator)
3437
x = op(cumulator[j])
@@ -38,21 +41,25 @@ function deg1_eval(
3841
end
3942

4043
function deg1_l2_ll0_lr0_eval(
41-
tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::Val{true}
44+
tree::AbstractExpressionNode{T},
45+
cX::AbstractMatrix{T},
46+
op::F,
47+
op_l::F2,
48+
eval_options::EvalOptions{true},
4249
) where {T<:Number,F,F2}
4350
if tree.l.l.constant && tree.l.r.constant
4451
val_ll = tree.l.l.val
4552
val_lr = tree.l.r.val
46-
@return_on_check val_ll cX
47-
@return_on_check val_lr cX
53+
@return_on_nonfinite_val(eval_options, val_ll, cX)
54+
@return_on_nonfinite_val(eval_options, val_lr, cX)
4855
x_l = op_l(val_ll, val_lr)::T
49-
@return_on_check x_l cX
56+
@return_on_nonfinite_val(eval_options, x_l, cX)
5057
x = op(x_l)::T
51-
@return_on_check x cX
58+
@return_on_nonfinite_val(eval_options, x, cX)
5259
return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
5360
elseif tree.l.l.constant
5461
val_ll = tree.l.l.val
55-
@return_on_check val_ll cX
62+
@return_on_nonfinite_val(eval_options, val_ll, cX)
5663
feature_lr = tree.l.r.feature
5764
cumulator = similar(cX, axes(cX, 2))
5865
@turbo for j in axes(cX, 2)
@@ -64,7 +71,7 @@ function deg1_l2_ll0_lr0_eval(
6471
elseif tree.l.r.constant
6572
feature_ll = tree.l.l.feature
6673
val_lr = tree.l.r.val
67-
@return_on_check val_lr cX
74+
@return_on_nonfinite_val(eval_options, val_lr, cX)
6875
cumulator = similar(cX, axes(cX, 2))
6976
@turbo for j in axes(cX, 2)
7077
x_l = op_l(cX[feature_ll, j], val_lr)
@@ -86,15 +93,19 @@ function deg1_l2_ll0_lr0_eval(
8693
end
8794

8895
function deg1_l1_ll0_eval(
89-
tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::Val{true}
96+
tree::AbstractExpressionNode{T},
97+
cX::AbstractMatrix{T},
98+
op::F,
99+
op_l::F2,
100+
eval_options::EvalOptions{true},
90101
) where {T<:Number,F,F2}
91102
if tree.l.l.constant
92103
val_ll = tree.l.l.val
93-
@return_on_check val_ll cX
104+
@return_on_nonfinite_val(eval_options, val_ll, cX)
94105
x_l = op_l(val_ll)::T
95-
@return_on_check x_l cX
106+
@return_on_nonfinite_val(eval_options, x_l, cX)
96107
x = op(x_l)::T
97-
@return_on_check x cX
108+
@return_on_nonfinite_val(eval_options, x, cX)
98109
return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
99110
else
100111
feature_ll = tree.l.l.feature
@@ -109,20 +120,23 @@ function deg1_l1_ll0_eval(
109120
end
110121

111122
function deg2_l0_r0_eval(
112-
tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, ::Val{true}
123+
tree::AbstractExpressionNode{T},
124+
cX::AbstractMatrix{T},
125+
op::F,
126+
eval_options::EvalOptions{true},
113127
) where {T<:Number,F}
114128
if tree.l.constant && tree.r.constant
115129
val_l = tree.l.val
116-
@return_on_check val_l cX
130+
@return_on_nonfinite_val(eval_options, val_l, cX)
117131
val_r = tree.r.val
118-
@return_on_check val_r cX
132+
@return_on_nonfinite_val(eval_options, val_r, cX)
119133
x = op(val_l, val_r)::T
120-
@return_on_check x cX
134+
@return_on_nonfinite_val(eval_options, x, cX)
121135
return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
122136
elseif tree.l.constant
123137
cumulator = similar(cX, axes(cX, 2))
124138
val_l = tree.l.val
125-
@return_on_check val_l cX
139+
@return_on_nonfinite_val(eval_options, val_l, cX)
126140
feature_r = tree.r.feature
127141
@turbo for j in axes(cX, 2)
128142
x = op(val_l, cX[feature_r, j])
@@ -133,7 +147,7 @@ function deg2_l0_r0_eval(
133147
cumulator = similar(cX, axes(cX, 2))
134148
feature_l = tree.l.feature
135149
val_r = tree.r.val
136-
@return_on_check val_r cX
150+
@return_on_nonfinite_val(eval_options, val_r, cX)
137151
@turbo for j in axes(cX, 2)
138152
x = op(cX[feature_l, j], val_r)
139153
cumulator[j] = x
@@ -157,11 +171,11 @@ function deg2_l0_eval(
157171
cumulator::AbstractVector{T},
158172
cX::AbstractArray{T},
159173
op::F,
160-
::Val{true},
174+
eval_options::EvalOptions{true},
161175
) where {T<:Number,F}
162176
if tree.l.constant
163177
val = tree.l.val
164-
@return_on_check val cX
178+
@return_on_nonfinite_val(eval_options, val, cX)
165179
@turbo for j in eachindex(cumulator)
166180
x = op(val, cumulator[j])
167181
cumulator[j] = x
@@ -182,11 +196,11 @@ function deg2_r0_eval(
182196
cumulator::AbstractVector{T},
183197
cX::AbstractArray{T},
184198
op::F,
185-
::Val{true},
199+
eval_options::EvalOptions{true},
186200
) where {T<:Number,F}
187201
if tree.r.constant
188202
val = tree.r.val
189-
@return_on_check val cX
203+
@return_on_nonfinite_val(eval_options, val, cX)
190204
@turbo for j in eachindex(cumulator)
191205
x = op(cumulator[j], val)
192206
cumulator[j] = x
@@ -203,11 +217,15 @@ function deg2_r0_eval(
203217
end
204218

205219
## Interface with Bumper.jl
206-
function bumper_kern1!(op::F, cumulator, ::Val{true}) where {F}
220+
function bumper_kern1!(
221+
op::F, cumulator, ::EvalOptions{true,true,early_exit}
222+
) where {F,early_exit}
207223
@turbo @. cumulator = op(cumulator)
208224
return cumulator
209225
end
210-
function bumper_kern2!(op::F, cumulator1, cumulator2, ::Val{true}) where {F}
226+
function bumper_kern2!(
227+
op::F, cumulator1, cumulator2, ::EvalOptions{true,true,early_exit}
228+
) where {F,early_exit}
211229
@turbo @. cumulator1 = op(cumulator1, cumulator2)
212230
return cumulator1
213231
end

src/DynamicExpressions.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ import .NodeModule:
7070
@reexport import .OperatorEnumModule: AbstractOperatorEnum
7171
@reexport import .OperatorEnumConstructionModule:
7272
OperatorEnum, GenericOperatorEnum, @extend_operators, set_default_variable_names!
73-
@reexport import .EvaluateModule: eval_tree_array, differentiable_eval_tree_array
73+
@reexport import .EvaluateModule:
74+
eval_tree_array, differentiable_eval_tree_array, EvalOptions
7475
@reexport import .EvaluateDerivativeModule: eval_diff_tree_array, eval_grad_tree_array
7576
@reexport import .ChainRulesModule: NodeTangent, extract_gradient
7677
@reexport import .SimplifyModule: combine_operators, simplify_tree!

0 commit comments

Comments
 (0)