Skip to content

Commit 3c5fc35

Browse files
committed
fix: try to improve type inference
1 parent b37c944 commit 3c5fc35

File tree

2 files changed

+12
-13
lines changed

2 files changed

+12
-13
lines changed

src/Evaluate.jl

+7-9
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,7 @@ import ..ExtensionInterfaceModule: bumper_eval_tree_array, _is_loopvectorization
1111
import ..ValueInterfaceModule: is_valid, is_valid_array
1212

1313
# Overloaded by SpecialOperators.jl:
14-
function any_special_operators(_)
15-
return false
16-
end
14+
function any_special_operators end
1715
function special_operator end
1816
function deg2_eval_special end
1917
function deg1_eval_special end
@@ -226,7 +224,7 @@ function eval_tree_array(
226224
"Bumper and LoopVectorization features are only compatible with numeric element types",
227225
)
228226
end
229-
if any_special_operators(typeof(operators))
227+
if any_special_operators(operators)
230228
cX = copy(cX)
231229
# TODO: This is dangerous if the element type is mutable
232230
end
@@ -338,7 +336,6 @@ end
338336
eval_options::EvalOptions,
339337
) where {T}
340338
nbin = get_nbin(operators)
341-
special_operators = any_special_operators(operators)
342339
long_compilation_time = nbin > OPERATOR_LIMIT_BEFORE_SLOWDOWN
343340
if long_compilation_time
344341
return quote
@@ -370,7 +367,7 @@ end
370367
@return_on_nonfinite_array(eval_options, result_l.x)
371368
# op(x, y), where y is a constant or variable but x is not.
372369
deg2_r0_eval(tree, result_l.x, cX, op, eval_options)
373-
elseif !$(special_operators) && tree.l.degree == 0
370+
elseif !any_special_operators(operators) && tree.l.degree == 0
374371
# This branch changes the execution order, so we cannot
375372
# use this branch when special operators are present.
376373
result_r = _eval_tree_array(tree.r, cX, operators, eval_options)
@@ -400,7 +397,6 @@ end
400397
eval_options::EvalOptions,
401398
) where {T}
402399
nuna = get_nuna(operators)
403-
special_operators = any_special_operators(operators)
404400
long_compilation_time = nuna > OPERATOR_LIMIT_BEFORE_SLOWDOWN
405401
if long_compilation_time
406402
return quote
@@ -422,7 +418,7 @@ end
422418
i -> let op = operators.unaops[i]
423419
if special_operator(op)
424420
deg1_eval_special(tree, cX, operators, op, eval_options)
425-
elseif !$(special_operators) &&
421+
elseif !any_special_operators(operators) &&
426422
tree.l.degree == 2 &&
427423
tree.l.l.degree == 0 &&
428424
tree.l.r.degree == 0
@@ -431,7 +427,9 @@ end
431427
dispatch_deg1_l2_ll0_lr0_eval(
432428
tree, cX, op, l_op_idx, operators.binops, eval_options
433429
)
434-
elseif !$(special_operators) && tree.l.degree == 1 && tree.l.l.degree == 0
430+
elseif !any_special_operators(operators) &&
431+
tree.l.degree == 1 &&
432+
tree.l.l.degree == 0
435433
# op(op2(x)), where x is a constant or variable.
436434
l_op_idx = tree.l.op
437435
dispatch_deg1_l1_ll0_eval(

src/SpecialOperators.jl

+5-4
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,17 @@ import ..EvaluateModule:
1010
special_operator, deg2_eval_special, deg1_eval_special, any_special_operators
1111
import ..StringsModule: get_op_name
1212

13+
14+
# Use this to customize evaluation behavior for operators:
15+
@inline special_operator(::Type{F}) where {F} = false
16+
@inline special_operator(::F) where {F} = special_operator(F)
17+
1318
@generated function any_special_operators(
1419
::Union{O,Type{O}}
1520
) where {B,U,O<:OperatorEnum{B,U}}
1621
return any(special_operator, B.types) || any(special_operator, U.types)
1722
end
1823

19-
# Use this to customize evaluation behavior for operators:
20-
@inline special_operator(::Type{F}) where {F} = false
21-
@inline special_operator(::F) where {F} = special_operator(F)
22-
2324
Base.@kwdef struct AssignOperator <: Function
2425
target_register::Int
2526
end

0 commit comments

Comments
 (0)