Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Turing complete expressions #123

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
2 changes: 2 additions & 0 deletions src/DynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ using DispatchDoctor: @stable, @unstable
include("OperatorEnumConstruction.jl")
include("Expression.jl")
include("ExpressionAlgebra.jl")
include("SpecialOperators.jl")
include("Random.jl")
include("Parse.jl")
include("ParametricExpression.jl")
Expand Down Expand Up @@ -76,6 +77,7 @@ import .StringsModule: get_op_name
@reexport import .EvaluateModule:
eval_tree_array, differentiable_eval_tree_array, EvalOptions
import .EvaluateModule: ArrayBuffer
@reexport import .SpecialOperatorsModule: AssignOperator, WhileOperator
@reexport import .EvaluateDerivativeModule: eval_diff_tree_array, eval_grad_tree_array
@reexport import .ChainRulesModule: NodeTangent, extract_gradient
@reexport import .SimplifyModule: combine_operators, simplify_tree!
Expand Down
41 changes: 34 additions & 7 deletions src/Evaluate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ import ..NodeUtilsModule: is_constant
import ..ExtensionInterfaceModule: bumper_eval_tree_array, _is_loopvectorization_loaded
import ..ValueInterfaceModule: is_valid, is_valid_array

# Overloaded by SpecialOperators.jl:
function any_special_operators end
function special_operator end
function deg2_eval_special end
function deg1_eval_special end

const OPERATOR_LIMIT_BEFORE_SLOWDOWN = 15

macro return_on_nonfinite_val(eval_options, val, X)
Expand Down Expand Up @@ -218,6 +224,10 @@ function eval_tree_array(
"Bumper and LoopVectorization features are only compatible with numeric element types",
)
end
if any_special_operators(operators)
cX = copy(cX)
# TODO: This is dangerous if the element type is mutable
end
if _eval_options.bumper isa Val{true}
return bumper_eval_tree_array(tree, cX, operators, _eval_options)
end
Expand Down Expand Up @@ -264,7 +274,7 @@ function _eval_tree_array(
# we can just return the constant result.
if tree.degree == 0
return deg0_eval(tree, cX, eval_options)
elseif is_constant(tree)
elseif !any_special_operators(operators) && is_constant(tree)
# Speed hack for constant trees.
const_result = dispatch_constant_tree(tree, operators)::ResultOk{T}
!const_result.ok &&
Expand Down Expand Up @@ -329,30 +339,37 @@ end
long_compilation_time = nbin > OPERATOR_LIMIT_BEFORE_SLOWDOWN
if long_compilation_time
return quote
op = operators.binops[op_idx]
special_operator(op) &&
return deg2_eval_special(tree, cX, operators, op, eval_options)
result_l = _eval_tree_array(tree.l, cX, operators, eval_options)
!result_l.ok && return result_l
@return_on_nonfinite_array(eval_options, result_l.x)
result_r = _eval_tree_array(tree.r, cX, operators, eval_options)
!result_r.ok && return result_r
@return_on_nonfinite_array(eval_options, result_r.x)
# op(x, y), for any x or y
deg2_eval(result_l.x, result_r.x, operators.binops[op_idx], eval_options)
deg2_eval(result_l.x, result_r.x, op, eval_options)
end
end
return quote
return Base.Cartesian.@nif(
$nbin,
i -> i == op_idx,
i -> let op = operators.binops[i]
if tree.l.degree == 0 && tree.r.degree == 0
if special_operator(op)
deg2_eval_special(tree, cX, operators, op, eval_options)
elseif tree.l.degree == 0 && tree.r.degree == 0
deg2_l0_r0_eval(tree, cX, op, eval_options)
elseif tree.r.degree == 0
result_l = _eval_tree_array(tree.l, cX, operators, eval_options)
!result_l.ok && return result_l
@return_on_nonfinite_array(eval_options, result_l.x)
# op(x, y), where y is a constant or variable but x is not.
deg2_r0_eval(tree, result_l.x, cX, op, eval_options)
elseif tree.l.degree == 0
elseif !any_special_operators(operators) && tree.l.degree == 0
# This branch changes the execution order, so we cannot
# use this branch when special operators are present.
result_r = _eval_tree_array(tree.r, cX, operators, eval_options)
!result_r.ok && return result_r
@return_on_nonfinite_array(eval_options, result_r.x)
Expand Down Expand Up @@ -383,10 +400,13 @@ end
long_compilation_time = nuna > OPERATOR_LIMIT_BEFORE_SLOWDOWN
if long_compilation_time
return quote
op = operators.unaops[op_idx]
special_operator(op) &&
return deg1_eval_special(tree, cX, operators, op, eval_options)
result = _eval_tree_array(tree.l, cX, operators, eval_options)
!result.ok && return result
@return_on_nonfinite_array(eval_options, result.x)
deg1_eval(result.x, operators.unaops[op_idx], eval_options)
deg1_eval(result.x, op, eval_options)
end
end
# This @nif lets us generate an if statement over choice of operator,
Expand All @@ -396,13 +416,20 @@ end
$nuna,
i -> i == op_idx,
i -> let op = operators.unaops[i]
if tree.l.degree == 2 && tree.l.l.degree == 0 && tree.l.r.degree == 0
if special_operator(op)
deg1_eval_special(tree, cX, operators, op, eval_options)
elseif !any_special_operators(operators) &&
tree.l.degree == 2 &&
tree.l.l.degree == 0 &&
tree.l.r.degree == 0
# op(op2(x, y)), where x, y, z are constants or variables.
l_op_idx = tree.l.op
dispatch_deg1_l2_ll0_lr0_eval(
tree, cX, op, l_op_idx, operators.binops, eval_options
)
elseif tree.l.degree == 1 && tree.l.l.degree == 0
elseif !any_special_operators(operators) &&
tree.l.degree == 1 &&
tree.l.l.degree == 0
# op(op2(x)), where x is a constant or variable.
l_op_idx = tree.l.op
dispatch_deg1_l1_ll0_eval(
Expand Down
18 changes: 15 additions & 3 deletions src/Simplify.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import ..NodeModule: AbstractExpressionNode, constructorof, Node, copy_node, set
import ..NodeUtilsModule: tree_mapreduce, is_node_constant
import ..OperatorEnumModule: AbstractOperatorEnum
import ..ValueInterfaceModule: is_valid
import ..EvaluateModule: any_special_operators

_una_op_kernel(f::F, l::T) where {F,T} = f(l)
_bin_op_kernel(f::F, l::T, r::T) where {F,T} = f(l, r)
Expand All @@ -19,6 +20,12 @@ combine_operators(tree::AbstractExpressionNode, ::AbstractOperatorEnum) = tree
# This is only defined for `Node` as it is not possible for, e.g.,
# `GraphNode`.
function combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where {T}
# Skip simplification if special operators are in use
any_special_operators(operators) && return tree
return _combine_operators(tree, operators)
end

function _combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where {T}
# NOTE: (const (+*-) const) already accounted for. Call simplify_tree! before.
# ((const + var) + const) => (const + var)
# ((const * var) * const) => (const * var)
Expand All @@ -28,10 +35,10 @@ function combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where
if tree.degree == 0
return tree
elseif tree.degree == 1
tree.l = combine_operators(tree.l, operators)
tree.l = _combine_operators(tree.l, operators)
elseif tree.degree == 2
tree.l = combine_operators(tree.l, operators)
tree.r = combine_operators(tree.r, operators)
tree.l = _combine_operators(tree.l, operators)
tree.r = _combine_operators(tree.r, operators)
end

top_level_constant =
Expand Down Expand Up @@ -123,6 +130,11 @@ end

# Simplify tree
function simplify_tree!(tree::AbstractExpressionNode, operators::AbstractOperatorEnum)
# Skip simplification if special operators are in use
if any_special_operators(operators)
return tree
end

return tree_mapreduce(
identity, (p, c...) -> combine_children!(operators, p, c...), tree, typeof(tree);
)
Expand Down
84 changes: 84 additions & 0 deletions src/SpecialOperators.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
module SpecialOperatorsModule

using ..OperatorEnumModule: OperatorEnum
using ..EvaluateModule:
_eval_tree_array, @return_on_nonfinite_array, deg2_eval, ResultOk, get_filled_array
using ..ExpressionModule: AbstractExpression
using ..ExpressionAlgebraModule: @declare_expression_operator

import ..EvaluateModule:
special_operator, deg2_eval_special, deg1_eval_special, any_special_operators
import ..StringsModule: get_op_name

# Use this to customize evaluation behavior for operators:
@inline special_operator(::Type{F}) where {F} = false
@inline special_operator(::F) where {F} = special_operator(F)

@generated function any_special_operators(
::Union{O,Type{O}}
) where {B,U,O<:OperatorEnum{B,U}}
return any(special_operator, B.types) || any(special_operator, U.types)
end

Base.@kwdef struct AssignOperator <: Function
target_register::Int
end
@declare_expression_operator((op::AssignOperator), 1)
@inline special_operator(::Type{AssignOperator}) = true
get_op_name(o::AssignOperator) = "ASSIGN_OP:{FEATURE_" * string(o.target_register) * "}"

function deg1_eval_special(tree, cX, operators, op::AssignOperator, eval_options)
result = _eval_tree_array(tree.l, cX, operators, eval_options)
!result.ok && return result
@return_on_nonfinite_array(eval_options, result.x)
target_register = op.target_register
@inbounds @simd for i in eachindex(axes(cX, 2))
cX[target_register, i] = result.x[i]
end
return result
end

Base.@kwdef struct WhileOperator <: Function
max_iters::Int = 100
end

@declare_expression_operator((op::WhileOperator), 2)
@inline special_operator(::Type{WhileOperator}) = true
get_op_name(o::WhileOperator) = "while"

# TODO: Need to void any instance of buffer when using while loop.
function deg2_eval_special(tree, cX, operators, op::WhileOperator, eval_options)
cond = tree.l
body = tree.r
mask = trues(size(cX, 2))
X = @view cX[:, mask]
# Initialize the result array for all columns
result_array = get_filled_array(eval_options.buffer, zero(eltype(cX)), cX, axes(cX, 2))
body_result = ResultOk(result_array, true)

for _ in 1:(op.max_iters)
cond_result = _eval_tree_array(cond, X, operators, eval_options)
!cond_result.ok && return cond_result
@return_on_nonfinite_array(eval_options, cond_result.x)

new_mask = cond_result.x .> 0.0
any(new_mask) || return body_result

# Track which columns are still active
mask[mask] .= new_mask
X = @view cX[:, mask]

# Evaluate just for active columns
iter_result = _eval_tree_array(body, X, operators, eval_options)
!iter_result.ok && return iter_result

# Update the corresponding elements in the result array
body_result.x[mask] .= iter_result.x
@return_on_nonfinite_array(eval_options, body_result.x)
end

# We passed max_iters, so this result is invalid
return ResultOk(body_result.x, false)
end

end
49 changes: 42 additions & 7 deletions src/Strings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,18 @@ end
end
end

const FEATURE_PLACEHOLDER_FIRST_HALF_LENGTH = length("{FEATURE_")
function replace_feature_placeholders(s::String, f_variable::Function, variable_names)
return replace(
s,
r"\{FEATURE_(\d+)\}" =>
m -> f_variable(
parse(Int, m[(begin + FEATURE_PLACEHOLDER_FIRST_HALF_LENGTH):(end - 1)]),
variable_names,
),
)
end

# Can overload these for custom behavior:
needs_brackets(val::Real) = false
needs_brackets(val::AbstractArray) = false
Expand Down Expand Up @@ -104,12 +116,33 @@ function combine_op_with_inputs(op, l, r)::Vector{Char}
end
end
function combine_op_with_inputs(op, l)
# "op(l)"
out = copy(op)
push!(out, '(')
append!(out, strip_brackets(l))
push!(out, ')')
return out
# Check if this is an assignment operator with our special prefix
op_str = String(op)
if startswith(op_str, "ASSIGN_OP:")
# Extract the variable name from the operator name
var_name = op_str[11:end]
# Format: (var ← expr)
out = ['(']
append!(out, collect(var_name))
append!(out, collect(" ← "))
# Ensure the expression is always wrapped in parentheses for clarity
if l[1] == '(' && l[end] == ')'
append!(out, l)
else
push!(out, '(')
append!(out, strip_brackets(l))
push!(out, ')')
end
push!(out, ')')
return out
else
# Regular unary operator: "op(l)"
out = copy(op)
push!(out, '(')
append!(out, strip_brackets(l))
push!(out, ')')
return out
end
end

"""
Expand Down Expand Up @@ -179,7 +212,9 @@ function string_tree(
c
end,
)
return String(strip_brackets(raw_output))
string_output = String(strip_brackets(raw_output))
string_output = replace_feature_placeholders(string_output, f_variable, variable_names)
return string_output
end

# Print an equation
Expand Down
Loading
Loading