diff --git a/src/remake.jl b/src/remake.jl index 8ed91237f..750413a1f 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -509,6 +509,21 @@ function varmap_get(varmap, var, default = nothing) return default end +""" + $(TYPEDSIGNATURES) + +Check if `varmap::Dict{Any, Any}` contains cyclic values for any symbolic variables in +`syms`. Falls back on the basis of `symbolic_container(indp)`. Returns `false` by default. +""" +function detect_cycles(indp, varmap, syms) + if hasmethod(symbolic_container, Tuple{typeof(indp)}) && + (sc = symbolic_container(indp)) != indp + return detect_cycles(sc, varmap, syms) + else + return false + end +end + anydict(d::Dict{Any, Any}) = d anydict(d) = Dict{Any, Any}(d) anydict() = Dict{Any, Any}() @@ -560,14 +575,24 @@ function _updated_u0_p_internal( end function fill_u0(prob, u0; defs = nothing, use_defaults = false) - vsyms = variable_symbols(prob) - idx_to_vsym = anydict(variable_index(prob, sym) => sym for sym in vsyms) + fill_vars(prob, u0; defs, use_defaults, allsyms = variable_symbols(prob), + index_function = variable_index) +end + +function fill_p(prob, p; defs = nothing, use_defaults = false) + fill_vars(prob, p; defs, use_defaults, allsyms = parameter_symbols(prob), + index_function = parameter_index) +end + +function fill_vars( + prob, varmap; defs = nothing, use_defaults = false, allsyms, index_function) + idx_to_vsym = anydict(index_function(prob, sym) => sym for sym in allsyms) sym_to_idx = anydict() idx_to_sym = anydict() idx_to_val = anydict() - for (k, v) in u0 + for (k, v) in varmap v === nothing && continue - idx = variable_index(prob, k) + idx = index_function(prob, k) idx === nothing && continue if !(idx isa AbstractArray) || symbolic_type(k) != ArraySymbolic() idx = (idx,) @@ -582,9 +607,9 @@ function fill_u0(prob, u0; defs = nothing, use_defaults = false) idx_to_val[ii] = vv end end - for sym in vsyms + for sym in allsyms haskey(sym_to_idx, sym) && continue - idx = variable_index(prob, sym) + idx = index_function(prob, sym) haskey(idx_to_val, idx) && continue sym_to_idx[sym] = idx idx_to_sym[idx] = sym @@ -600,65 +625,33 @@ function fill_u0(prob, u0; defs = nothing, use_defaults = false) for (idx, val) in idx_to_val newvals[idx_to_sym[idx]] = val end - for (k, v) in u0 + for (k, v) in varmap haskey(sym_to_idx, k) && continue newvals[k] = v end return newvals end -function fill_p(prob, p; defs = nothing, use_defaults = false) - psyms = parameter_symbols(prob) - idx_to_psym = anydict(parameter_index(prob, sym) => sym for sym in psyms) - sym_to_idx = anydict() - idx_to_sym = anydict() - idx_to_val = anydict() - for (k, v) in p - v === nothing && continue - idx = parameter_index(prob, k) - idx === nothing && continue - if !(idx isa AbstractArray) || symbolic_type(k) != ArraySymbolic() - idx = (idx,) - k = (k,) - v = (v,) - end - for (kk, vv, ii) in zip(k, v, idx) - sym_to_idx[kk] = ii - kk = idx_to_psym[ii] - sym_to_idx[kk] = ii - idx_to_sym[ii] = kk - idx_to_val[ii] = vv - end - end - for sym in psyms - haskey(sym_to_idx, sym) && continue - idx = parameter_index(prob, sym) - haskey(idx_to_val, idx) && continue - sym_to_idx[sym] = idx - idx_to_sym[idx] = sym - idx_to_val[idx] = if defs !== nothing && - (defval = varmap_get(defs, sym)) !== nothing && - (symbolic_type(defval) != NotSymbolic() || use_defaults) - defval - else - getp(prob, sym)(prob) - end - end - newvals = anydict() - for (idx, val) in idx_to_val - newvals[idx_to_sym[idx]] = val - end - for (k, v) in p - haskey(sym_to_idx, k) && continue - newvals[k] = v +struct CyclicDependencyError <: Exception + varmap::Dict{Any, Any} + vars::Any +end + +function Base.showerror(io::IO, err::CyclicDependencyError) + println(io, "Detected cyclic dependency in initial values:") + for (k, v) in err.varmap + println(io, k, " => ", "v") end - return newvals + println(io, "While trying to solve for variables: ", err.vars) end function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false}, t0) isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in u0) isdep || return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), p + if detect_cycles(prob, u0, variable_symbols(prob)) + throw(CyclicDependencyError(u0, variable_symbols(prob))) + end for (k, v) in u0 u0[k] = symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, u0) end @@ -680,6 +673,9 @@ function _updated_u0_p_symmap(prob, u0, ::Val{false}, p, ::Val{true}, t0) isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in p) isdep || return u0, remake_buffer(prob, parameter_values(prob), keys(p), values(p)) + if detect_cycles(prob, p, parameter_symbols(prob)) + throw(CyclicDependencyError(p, parameter_symbols(prob))) + end for (k, v) in p p[k] = symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, p) end @@ -707,6 +703,10 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true}, t0) end varmap = merge(u0, p) + allsyms = [variable_symbols(prob); parameter_symbols(prob)] + if detect_cycles(prob, varmap, allsyms) + throw(CyclicDependencyError(varmap, allsyms)) + end if is_time_dependent(prob) varmap[only(independent_variable_symbols(prob))] = t0 end diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index f86864caf..1d0041cea 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -209,7 +209,7 @@ indices that can be plotted as continuous variables. This is useful for systems that store auxiliary variables in the state vector which are not meant to be used for plotting. """ -plottable_indices(x:: AbstractArray) = 1:length(x) +plottable_indices(x::AbstractArray) = 1:length(x) plottable_indices(x::Number) = 1 @recipe function f(sol::AbstractTimeseriesSolution; diff --git a/test/downstream/modelingtoolkit_remake.jl b/test/downstream/modelingtoolkit_remake.jl index a10dd09c7..228a26dd6 100644 --- a/test/downstream/modelingtoolkit_remake.jl +++ b/test/downstream/modelingtoolkit_remake.jl @@ -248,3 +248,29 @@ end prob2 = remake(prob; u0 = [x => t + 3.0]) @test prob2[x] ≈ 3.0 end + +@static if length(methods(SciMLBase.detect_cycles)) == 1 + function SciMLBase.detect_cycles( + ::ModelingToolkit.AbstractSystem, varmap::Dict{Any, Any}, vars) + for sym in vars + if symbolic_type(ModelingToolkit.fixpoint_sub(sym, varmap; maxiters = 10)) != + NotSymbolic() + return true + end + end + return false + end +end + +@testset "Cycle detection" begin + @variables x(t) y(t) + @parameters p q + @mtkbuild sys = ODESystem([D(x) ~ x * p, D(y) ~ y * q], t) + + prob = ODEProblem(sys, [x => 1.0, y => 1.0], (0.0, 1.0), [p => 1.0, q => 1.0]) + @test_throws SciMLBase.CyclicDependencyError remake( + prob; u0 = [x => 2y + 3, y => 2x + 1]) + @test_throws SciMLBase.CyclicDependencyError remake(prob; p = [p => 2q + 1, q => p + 3]) + @test_throws SciMLBase.CyclicDependencyError remake( + prob; u0 = [x => 2y + p, y => q + 3], p = [p => x + y, q => p + 3]) +end