From af35520083eb8e6218b34af09225448e0bc1c8e9 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 8 Nov 2024 13:22:50 +0530 Subject: [PATCH 1/4] refactor: format --- src/solutions/solution_interface.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index 7cca426f5..41a2681f3 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -209,7 +209,7 @@ indices that can be plotted as continuous variables. This is useful for systems that store auxiliary variables in the state vector which are not meant to be used for plotting. """ -plottable_indices(x:: AbstractArray) = 1:length(x) +plottable_indices(x::AbstractArray) = 1:length(x) plottable_indices(x::Number) = 1 @recipe function f(sol::AbstractTimeseriesSolution; From e4e7c8b032729a3878b72399cafdda960d956e97 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 8 Nov 2024 13:22:39 +0530 Subject: [PATCH 2/4] refactor: merge `fill_u0` and `fill_p` implementations --- src/remake.jl | 70 ++++++++++++--------------------------------------- 1 file changed, 16 insertions(+), 54 deletions(-) diff --git a/src/remake.jl b/src/remake.jl index 0ffd91003..a4437c207 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -560,62 +560,24 @@ function _updated_u0_p_internal( end function fill_u0(prob, u0; defs = nothing, use_defaults = false) - vsyms = variable_symbols(prob) - idx_to_vsym = anydict(variable_index(prob, sym) => sym for sym in vsyms) - sym_to_idx = anydict() - idx_to_sym = anydict() - idx_to_val = anydict() - for (k, v) in u0 - v === nothing && continue - idx = variable_index(prob, k) - idx === nothing && continue - if !(idx isa AbstractArray) || symbolic_type(k) != ArraySymbolic() - idx = (idx,) - k = (k,) - v = (v,) - end - for (kk, vv, ii) in zip(k, v, idx) - sym_to_idx[kk] = ii - kk = idx_to_vsym[ii] - sym_to_idx[kk] = ii - idx_to_sym[ii] = kk - idx_to_val[ii] = vv - end - end - for sym in vsyms - haskey(sym_to_idx, sym) && continue - idx = variable_index(prob, sym) - haskey(idx_to_val, idx) && continue - sym_to_idx[sym] = idx - idx_to_sym[idx] = sym - idx_to_val[idx] = if defs !== nothing && - (defval = varmap_get(defs, sym)) !== nothing && - (symbolic_type(defval) != NotSymbolic() || use_defaults) - defval - else - getu(prob, sym)(prob) - end - end - newvals = anydict() - for (idx, val) in idx_to_val - newvals[idx_to_sym[idx]] = val - end - for (k, v) in u0 - haskey(sym_to_idx, k) && continue - newvals[k] = v - end - return newvals + fill_vars(prob, u0; defs, use_defaults, allsyms = variable_symbols(prob), + index_function = variable_index) end function fill_p(prob, p; defs = nothing, use_defaults = false) - psyms = parameter_symbols(prob) - idx_to_psym = anydict(parameter_index(prob, sym) => sym for sym in psyms) + fill_vars(prob, p; defs, use_defaults, allsyms = parameter_symbols(prob), + index_function = parameter_index) +end + +function fill_vars( + prob, varmap; defs = nothing, use_defaults = false, allsyms, index_function) + idx_to_vsym = anydict(variable_index(prob, sym) => sym for sym in allsyms) sym_to_idx = anydict() idx_to_sym = anydict() idx_to_val = anydict() - for (k, v) in p + for (k, v) in varmap v === nothing && continue - idx = parameter_index(prob, k) + idx = index_function(prob, k) idx === nothing && continue if !(idx isa AbstractArray) || symbolic_type(k) != ArraySymbolic() idx = (idx,) @@ -624,15 +586,15 @@ function fill_p(prob, p; defs = nothing, use_defaults = false) end for (kk, vv, ii) in zip(k, v, idx) sym_to_idx[kk] = ii - kk = idx_to_psym[ii] + kk = idx_to_vsym[ii] sym_to_idx[kk] = ii idx_to_sym[ii] = kk idx_to_val[ii] = vv end end - for sym in psyms + for sym in allsyms haskey(sym_to_idx, sym) && continue - idx = parameter_index(prob, sym) + idx = index_function(prob, sym) haskey(idx_to_val, idx) && continue sym_to_idx[sym] = idx idx_to_sym[idx] = sym @@ -641,14 +603,14 @@ function fill_p(prob, p; defs = nothing, use_defaults = false) (symbolic_type(defval) != NotSymbolic() || use_defaults) defval else - getp(prob, sym)(prob) + getsym(prob, sym)(prob) end end newvals = anydict() for (idx, val) in idx_to_val newvals[idx_to_sym[idx]] = val end - for (k, v) in p + for (k, v) in varmap haskey(sym_to_idx, k) && continue newvals[k] = v end From 0b7d8fac3568ba6db4b2ff933f6fe50a89b8690d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 8 Nov 2024 14:01:19 +0530 Subject: [PATCH 3/4] feat: add cycle detection in initial conditions --- src/remake.jl | 40 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/src/remake.jl b/src/remake.jl index a4437c207..43812df11 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -509,6 +509,21 @@ function varmap_get(varmap, var, default = nothing) return default end +""" + $(TYPEDSIGNATURES) + +Check if `varmap::Dict{Any, Any}` contains cyclic values for any symbolic variables in +`syms`. Falls back on the basis of `symbolic_container(indp)`. Returns `false` by default. +""" +function detect_cycles(indp, varmap, syms) + if hasmethod(symbolic_container, Tuple{typeof(indp)}) && + (sc = symbolic_container(indp)) != indp + return detect_cycles(sc, varmap, syms) + else + return false + end +end + anydict(d::Dict{Any, Any}) = d anydict(d) = Dict{Any, Any}(d) anydict() = Dict{Any, Any}() @@ -571,7 +586,7 @@ end function fill_vars( prob, varmap; defs = nothing, use_defaults = false, allsyms, index_function) - idx_to_vsym = anydict(variable_index(prob, sym) => sym for sym in allsyms) + idx_to_vsym = anydict(index_function(prob, sym) => sym for sym in allsyms) sym_to_idx = anydict() idx_to_sym = anydict() idx_to_val = anydict() @@ -617,10 +632,26 @@ function fill_vars( return newvals end +struct CyclicDependencyError <: Exception + varmap::Dict{Any, Any} + vars::Any +end + +function Base.showerror(io::IO, err::CyclicDependencyError) + println(io, "Detected cyclic dependency in initial values:") + for (k, v) in err.varmap + println(io, k, " => ", "v") + end + println(io, "While trying to solve for variables: ", err.vars) +end + function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false}, t0) isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in u0) isdep || return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), p + if detect_cycles(prob, u0, variable_symbols(prob)) + throw(CyclicDependencyError(u0, variable_symbols(prob))) + end for (k, v) in u0 u0[k] = symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, u0) end @@ -642,6 +673,9 @@ function _updated_u0_p_symmap(prob, u0, ::Val{false}, p, ::Val{true}, t0) isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in p) isdep || return u0, remake_buffer(prob, parameter_values(prob), keys(p), values(p)) + if detect_cycles(prob, p, parameter_symbols(prob)) + throw(CyclicDependencyError(p, parameter_symbols(prob))) + end for (k, v) in p p[k] = symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, p) end @@ -669,6 +703,10 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true}, t0) end varmap = merge(u0, p) + allsyms = [variable_symbols(prob); parameter_symbols(prob)] + if detect_cycles(prob, varmap, allsyms) + throw(CyclicDependencyError(varmap, allsyms)) + end if is_time_dependent(prob) varmap[only(independent_variable_symbols(prob))] = t0 end From b3949edfffeebe15af1c7962222a621ae8eb6d0f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 8 Nov 2024 15:28:25 +0530 Subject: [PATCH 4/4] test: add tests for cycle detection --- test/downstream/modelingtoolkit_remake.jl | 26 +++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/test/downstream/modelingtoolkit_remake.jl b/test/downstream/modelingtoolkit_remake.jl index d0df44658..75fcec09d 100644 --- a/test/downstream/modelingtoolkit_remake.jl +++ b/test/downstream/modelingtoolkit_remake.jl @@ -248,3 +248,29 @@ end prob2 = remake(prob; u0 = [x => t + 3.0]) @test prob2[x] ≈ 3.0 end + +@static if length(methods(SciMLBase.detect_cycles)) == 1 + function SciMLBase.detect_cycles( + ::ModelingToolkit.AbstractSystem, varmap::Dict{Any, Any}, vars) + for sym in vars + if symbolic_type(ModelingToolkit.fixpoint_sub(sym, varmap; maxiters = 10)) != + NotSymbolic() + return true + end + end + return false + end +end + +@testset "Cycle detection" begin + @variables x(t) y(t) + @parameters p q + @mtkbuild sys = ODESystem([D(x) ~ x * p, D(y) ~ y * q], t) + + prob = ODEProblem(sys, [x => 1.0, y => 1.0], (0.0, 1.0), [p => 1.0, q => 1.0]) + @test_throws SciMLBase.CyclicDependencyError remake( + prob; u0 = [x => 2y + 3, y => 2x + 1]) + @test_throws SciMLBase.CyclicDependencyError remake(prob; p = [p => 2q + 1, q => p + 3]) + @test_throws SciMLBase.CyclicDependencyError remake( + prob; u0 = [x => 2y + p, y => q + 3], p = [p => x + y, q => p + 3]) +end