Skip to content

Commit

Permalink
Merge pull request #852 from AayushSabharwal/as/remake-debugging
Browse files Browse the repository at this point in the history
feat: add cycle detection in initial conditions
  • Loading branch information
ChrisRackauckas authored Nov 8, 2024
2 parents 0453006 + b3949ed commit d366481
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 53 deletions.
104 changes: 52 additions & 52 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,21 @@ function varmap_get(varmap, var, default = nothing)
return default
end

"""
$(TYPEDSIGNATURES)
Check if `varmap::Dict{Any, Any}` contains cyclic values for any symbolic variables in
`syms`. Falls back on the basis of `symbolic_container(indp)`. Returns `false` by default.
"""
function detect_cycles(indp, varmap, syms)
if hasmethod(symbolic_container, Tuple{typeof(indp)}) &&
(sc = symbolic_container(indp)) != indp
return detect_cycles(sc, varmap, syms)
else
return false
end
end

anydict(d::Dict{Any, Any}) = d
anydict(d) = Dict{Any, Any}(d)
anydict() = Dict{Any, Any}()
Expand Down Expand Up @@ -560,14 +575,24 @@ function _updated_u0_p_internal(
end

function fill_u0(prob, u0; defs = nothing, use_defaults = false)
vsyms = variable_symbols(prob)
idx_to_vsym = anydict(variable_index(prob, sym) => sym for sym in vsyms)
fill_vars(prob, u0; defs, use_defaults, allsyms = variable_symbols(prob),
index_function = variable_index)
end

function fill_p(prob, p; defs = nothing, use_defaults = false)
fill_vars(prob, p; defs, use_defaults, allsyms = parameter_symbols(prob),
index_function = parameter_index)
end

function fill_vars(
prob, varmap; defs = nothing, use_defaults = false, allsyms, index_function)
idx_to_vsym = anydict(index_function(prob, sym) => sym for sym in allsyms)
sym_to_idx = anydict()
idx_to_sym = anydict()
idx_to_val = anydict()
for (k, v) in u0
for (k, v) in varmap
v === nothing && continue
idx = variable_index(prob, k)
idx = index_function(prob, k)
idx === nothing && continue
if !(idx isa AbstractArray) || symbolic_type(k) != ArraySymbolic()
idx = (idx,)
Expand All @@ -582,9 +607,9 @@ function fill_u0(prob, u0; defs = nothing, use_defaults = false)
idx_to_val[ii] = vv
end
end
for sym in vsyms
for sym in allsyms
haskey(sym_to_idx, sym) && continue
idx = variable_index(prob, sym)
idx = index_function(prob, sym)
haskey(idx_to_val, idx) && continue
sym_to_idx[sym] = idx
idx_to_sym[idx] = sym
Expand All @@ -600,65 +625,33 @@ function fill_u0(prob, u0; defs = nothing, use_defaults = false)
for (idx, val) in idx_to_val
newvals[idx_to_sym[idx]] = val
end
for (k, v) in u0
for (k, v) in varmap
haskey(sym_to_idx, k) && continue
newvals[k] = v
end
return newvals
end

function fill_p(prob, p; defs = nothing, use_defaults = false)
psyms = parameter_symbols(prob)
idx_to_psym = anydict(parameter_index(prob, sym) => sym for sym in psyms)
sym_to_idx = anydict()
idx_to_sym = anydict()
idx_to_val = anydict()
for (k, v) in p
v === nothing && continue
idx = parameter_index(prob, k)
idx === nothing && continue
if !(idx isa AbstractArray) || symbolic_type(k) != ArraySymbolic()
idx = (idx,)
k = (k,)
v = (v,)
end
for (kk, vv, ii) in zip(k, v, idx)
sym_to_idx[kk] = ii
kk = idx_to_psym[ii]
sym_to_idx[kk] = ii
idx_to_sym[ii] = kk
idx_to_val[ii] = vv
end
end
for sym in psyms
haskey(sym_to_idx, sym) && continue
idx = parameter_index(prob, sym)
haskey(idx_to_val, idx) && continue
sym_to_idx[sym] = idx
idx_to_sym[idx] = sym
idx_to_val[idx] = if defs !== nothing &&
(defval = varmap_get(defs, sym)) !== nothing &&
(symbolic_type(defval) != NotSymbolic() || use_defaults)
defval
else
getp(prob, sym)(prob)
end
end
newvals = anydict()
for (idx, val) in idx_to_val
newvals[idx_to_sym[idx]] = val
end
for (k, v) in p
haskey(sym_to_idx, k) && continue
newvals[k] = v
struct CyclicDependencyError <: Exception
varmap::Dict{Any, Any}
vars::Any
end

function Base.showerror(io::IO, err::CyclicDependencyError)
println(io, "Detected cyclic dependency in initial values:")
for (k, v) in err.varmap
println(io, k, " => ", "v")
end
return newvals
println(io, "While trying to solve for variables: ", err.vars)
end

function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false}, t0)
isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in u0)
isdep || return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), p

if detect_cycles(prob, u0, variable_symbols(prob))
throw(CyclicDependencyError(u0, variable_symbols(prob)))
end
for (k, v) in u0
u0[k] = symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, u0)
end
Expand All @@ -680,6 +673,9 @@ function _updated_u0_p_symmap(prob, u0, ::Val{false}, p, ::Val{true}, t0)
isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in p)
isdep || return u0, remake_buffer(prob, parameter_values(prob), keys(p), values(p))

if detect_cycles(prob, p, parameter_symbols(prob))
throw(CyclicDependencyError(p, parameter_symbols(prob)))
end
for (k, v) in p
p[k] = symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, p)
end
Expand Down Expand Up @@ -707,6 +703,10 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true}, t0)
end

varmap = merge(u0, p)
allsyms = [variable_symbols(prob); parameter_symbols(prob)]
if detect_cycles(prob, varmap, allsyms)
throw(CyclicDependencyError(varmap, allsyms))
end
if is_time_dependent(prob)
varmap[only(independent_variable_symbols(prob))] = t0
end
Expand Down
2 changes: 1 addition & 1 deletion src/solutions/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ indices that can be plotted as continuous variables. This is useful for systems
that store auxiliary variables in the state vector which are not meant to be
used for plotting.
"""
plottable_indices(x:: AbstractArray) = 1:length(x)
plottable_indices(x::AbstractArray) = 1:length(x)
plottable_indices(x::Number) = 1

@recipe function f(sol::AbstractTimeseriesSolution;
Expand Down
26 changes: 26 additions & 0 deletions test/downstream/modelingtoolkit_remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,3 +248,29 @@ end
prob2 = remake(prob; u0 = [x => t + 3.0])
@test prob2[x] 3.0
end

@static if length(methods(SciMLBase.detect_cycles)) == 1
function SciMLBase.detect_cycles(
::ModelingToolkit.AbstractSystem, varmap::Dict{Any, Any}, vars)
for sym in vars
if symbolic_type(ModelingToolkit.fixpoint_sub(sym, varmap; maxiters = 10)) !=
NotSymbolic()
return true
end
end
return false
end
end

@testset "Cycle detection" begin
@variables x(t) y(t)
@parameters p q
@mtkbuild sys = ODESystem([D(x) ~ x * p, D(y) ~ y * q], t)

prob = ODEProblem(sys, [x => 1.0, y => 1.0], (0.0, 1.0), [p => 1.0, q => 1.0])
@test_throws SciMLBase.CyclicDependencyError remake(
prob; u0 = [x => 2y + 3, y => 2x + 1])
@test_throws SciMLBase.CyclicDependencyError remake(prob; p = [p => 2q + 1, q => p + 3])
@test_throws SciMLBase.CyclicDependencyError remake(
prob; u0 = [x => 2y + p, y => q + 3], p = [p => x + y, q => p + 3])
end

0 comments on commit d366481

Please sign in to comment.