Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Hierarchical printing of graph-like expressions #59

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/DynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ include("Utils.jl")
include("OperatorEnum.jl")
include("Equation.jl")
include("EquationUtils.jl")
include("Strings.jl")
include("EvaluateEquation.jl")
include("EvaluateEquationDerivative.jl")
include("EvaluationHelpers.jl")
Expand All @@ -19,8 +20,6 @@ import Reexport: @reexport
AbstractExpressionNode,
GraphNode,
Node,
string_tree,
print_tree,
copy_node,
set_node!,
tree_mapreduce,
Expand All @@ -37,6 +36,7 @@ import .EquationModule: constructorof, preserve_sharing
has_constants,
get_constants,
set_constants!
@reexport import .StringsModule: string_tree, print_tree, pretty_string_graph
@reexport import .OperatorEnumModule: AbstractOperatorEnum
@reexport import .OperatorEnumConstructionModule:
OperatorEnum, GenericOperatorEnum, @extend_operators, set_default_variable_names!
Expand Down
177 changes: 0 additions & 177 deletions src/Equation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -324,181 +324,4 @@ function set_node!(tree::AbstractExpressionNode, new_tree::AbstractExpressionNod
return nothing
end

const OP_NAMES = Base.ImmutableDict(
"safe_log" => "log",
"safe_log2" => "log2",
"safe_log10" => "log10",
"safe_log1p" => "log1p",
"safe_acosh" => "acosh",
"safe_sqrt" => "sqrt",
"safe_pow" => "^",
)

function dispatch_op_name(::Val{2}, ::Nothing, idx)::Vector{Char}
return vcat(collect("binary_operator["), collect(string(idx)), [']'])
end
function dispatch_op_name(::Val{1}, ::Nothing, idx)::Vector{Char}
return vcat(collect("unary_operator["), collect(string(idx)), [']'])
end
function dispatch_op_name(::Val{2}, operators::AbstractOperatorEnum, idx)::Vector{Char}
return get_op_name(operators.binops[idx])
end
function dispatch_op_name(::Val{1}, operators::AbstractOperatorEnum, idx)::Vector{Char}
return get_op_name(operators.unaops[idx])
end

@generated function get_op_name(op::F)::Vector{Char} where {F}
try
# Bit faster to just cache the name of the operator:
op_s = string(F.instance)
out = collect(get(OP_NAMES, op_s, op_s))
return :($out)
catch
end
return quote
op_s = string(op)
out = collect(get(OP_NAMES, op_s, op_s))
return out
end
end

@inline function strip_brackets(s::Vector{Char})::Vector{Char}
if first(s) == '(' && last(s) == ')'
return s[(begin + 1):(end - 1)]
else
return s
end
end

# Can overload these for custom behavior:
needs_brackets(val::Real) = false
needs_brackets(val::AbstractArray) = false
needs_brackets(val::Complex) = true
needs_brackets(val) = true

function string_constant(val)
if needs_brackets(val)
'(' * string(val) * ')'
else
string(val)
end
end

function string_variable(feature, variable_names)
if variable_names === nothing || feature > lastindex(variable_names)
return 'x' * string(feature)
else
return variable_names[feature]
end
end

# Vector of chars is faster than strings, so we use that.
function combine_op_with_inputs(op, l, r)::Vector{Char}
if first(op) in ('+', '-', '*', '/', '^')
# "(l op r)"
out = ['(']
append!(out, l)
push!(out, ' ')
append!(out, op)
push!(out, ' ')
append!(out, r)
push!(out, ')')
else
# "op(l, r)"
out = copy(op)
push!(out, '(')
append!(out, strip_brackets(l))
push!(out, ',')
push!(out, ' ')
append!(out, strip_brackets(r))
push!(out, ')')
return out
end
end
function combine_op_with_inputs(op, l)
# "op(l)"
out = copy(op)
push!(out, '(')
append!(out, strip_brackets(l))
push!(out, ')')
return out
end

"""
string_tree(
tree::AbstractExpressionNode{T},
operators::Union{AbstractOperatorEnum,Nothing}=nothing;
f_variable::F1=string_variable,
f_constant::F2=string_constant,
variable_names::Union{Array{String,1},Nothing}=nothing,
# Deprecated
varMap=nothing,
)::String where {T,F1<:Function,F2<:Function}

Convert an equation to a string.

# Arguments
- `tree`: the tree to convert to a string
- `operators`: the operators used to define the tree

# Keyword Arguments
- `f_variable`: (optional) function to convert a variable to a string, with arguments `(feature::UInt8, variable_names)`.
- `f_constant`: (optional) function to convert a constant to a string, with arguments `(val,)`
- `variable_names::Union{Array{String, 1}, Nothing}=nothing`: (optional) what variables to print for each feature.
"""
function string_tree(
tree::AbstractExpressionNode{T},
operators::Union{AbstractOperatorEnum,Nothing}=nothing;
f_variable::F1=string_variable,
f_constant::F2=string_constant,
variable_names::Union{Array{String,1},Nothing}=nothing,
# Deprecated
varMap=nothing,
)::String where {T,F1<:Function,F2<:Function}
variable_names = deprecate_varmap(variable_names, varMap, :string_tree)
raw_output = tree_mapreduce(
leaf -> if leaf.constant
collect(f_constant(leaf.val::T))
else
collect(f_variable(leaf.feature, variable_names))
end,
branch -> if branch.degree == 1
dispatch_op_name(Val(1), operators, branch.op)
else
dispatch_op_name(Val(2), operators, branch.op)
end,
combine_op_with_inputs,
tree,
Vector{Char};
f_on_shared=(c, is_shared) -> if is_shared
out = ['{']
append!(out, c)
push!(out, '}')
out
else
c
end,
)
return String(strip_brackets(raw_output))
end

# Print an equation
for io in ((), (:(io::IO),))
@eval function print_tree(
$(io...),
tree::AbstractExpressionNode,
operators::Union{AbstractOperatorEnum,Nothing}=nothing;
f_variable::F1=string_variable,
f_constant::F2=string_constant,
variable_names::Union{Array{String,1},Nothing}=nothing,
# Deprecated
varMap=nothing,
) where {F1<:Function,F2<:Function}
variable_names = deprecate_varmap(variable_names, varMap, :print_tree)
return println(
$(io...), string_tree(tree, operators; f_variable, f_constant, variable_names)
)
end
end

end
3 changes: 2 additions & 1 deletion src/EvaluateEquation.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
module EvaluateEquationModule

import LoopVectorization: @turbo, indices
import ..EquationModule: AbstractExpressionNode, constructorof, string_tree
import ..EquationModule: AbstractExpressionNode, constructorof
import ..OperatorEnumModule: OperatorEnum, GenericOperatorEnum
import ..UtilsModule: @return_on_false, @maybe_turbo, is_bad_array, fill_similar
import ..StringsModule: string_tree
import ..EquationUtilsModule: is_constant

macro return_on_check(val, X)
Expand Down
3 changes: 2 additions & 1 deletion src/OperatorEnumConstruction.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
module OperatorEnumConstructionModule

import ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum, GenericOperatorEnum
import ..EquationModule: string_tree, Node, GraphNode, AbstractExpressionNode, constructorof
import ..EquationModule: Node, GraphNode, AbstractExpressionNode, constructorof
import ..StringsModule: string_tree
import ..EvaluateEquationModule: eval_tree_array
import ..EvaluateEquationDerivativeModule: eval_grad_tree_array, _zygote_gradient
import ..EvaluationHelpersModule: _grad_evaluator
Expand Down
Loading