Skip to content

Commit

Permalink
Merge pull request #837 from AayushSabharwal/as/remake-fix
Browse files Browse the repository at this point in the history
fix: fix remake with u0 dependent on `Symbol` parameter
  • Loading branch information
ChrisRackauckas authored Oct 30, 2024
2 parents dd0da91 + 95c7fff commit babd239
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 9 deletions.
4 changes: 1 addition & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,6 @@ OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
RCall = "6f49c342-dc21-5d91-9882-a32aef131414"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
Expand All @@ -116,4 +114,4 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Pkg", "Plots", "UnicodePlots", "PyCall", "PythonCall", "SafeTestsets", "Serialization", "Test", "StableRNGs", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "OrdinaryDiffEq", "ForwardDiff", "Tables"]
test = ["Pkg", "Plots", "UnicodePlots", "SafeTestsets", "Serialization", "Test", "StableRNGs", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "OrdinaryDiffEq", "ForwardDiff", "Tables"]
9 changes: 9 additions & 0 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,7 @@ end

function fill_u0(prob, u0; defs = nothing, use_defaults = false)
vsyms = variable_symbols(prob)
idx_to_vsym = anydict(variable_index(prob, sym) => sym for sym in vsyms)
sym_to_idx = anydict()
idx_to_sym = anydict()
idx_to_val = anydict()
Expand All @@ -580,6 +581,8 @@ function fill_u0(prob, u0; defs = nothing, use_defaults = false)
v = (v,)
end
for (kk, vv, ii) in zip(k, v, idx)
sym_to_idx[kk] = ii
kk = idx_to_vsym[ii]
sym_to_idx[kk] = ii
idx_to_sym[ii] = kk
idx_to_val[ii] = vv
Expand Down Expand Up @@ -612,6 +615,7 @@ end

function fill_p(prob, p; defs = nothing, use_defaults = false)
psyms = parameter_symbols(prob)
idx_to_psym = anydict(parameter_index(prob, sym) => sym for sym in psyms)
sym_to_idx = anydict()
idx_to_sym = anydict()
idx_to_val = anydict()
Expand All @@ -625,6 +629,8 @@ function fill_p(prob, p; defs = nothing, use_defaults = false)
v = (v,)
end
for (kk, vv, ii) in zip(k, v, idx)
sym_to_idx[kk] = ii
kk = idx_to_psym[ii]
sym_to_idx[kk] = ii
idx_to_sym[ii] = kk
idx_to_val[ii] = vv
Expand Down Expand Up @@ -707,6 +713,9 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true}, t0)
end

varmap = merge(u0, p)
if is_time_dependent(prob)
varmap[only(independent_variable_symbols(prob))] = t0
end
for (k, v) in u0
u0[k] = symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, varmap)
end
Expand Down
19 changes: 19 additions & 0 deletions test/downstream/modelingtoolkit_remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,3 +229,22 @@ end
prob2 = remake(prob; u0 = [y => 3.0], p = Dict())
@test prob2.ps[p] 4.0
end

@testset "u0 dependent on parameter given as Symbol" begin
@variables x(t)
@parameters p
@mtkbuild sys = ODESystem([D(x) ~ x * p], t)
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [p => 1.0])
@test prob.ps[p] 1.0
prob2 = remake(prob; u0 = [x => p], p = [:p => 2.0])
@test prob2[x] 2.0
end

@testset "remake dependent on indepvar" begin
@variables x(t)
@parameters p
@mtkbuild sys = ODESystem([D(x) ~ x * p], t)
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [p => 1.0])
prob2 = remake(prob; u0 = [x => t + 3.0])
@test prob2[x] 3.0
end
2 changes: 1 addition & 1 deletion test/downstream/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ end
prob = ODEProblem(sys, [x => 1.0, y => 1.0], (0.0, 5.0),
[p => 0.5, q => 0.0, r => 1.0, s => 10.0, u => 4096.0])
ss = SciMLBase.SavedSubsystem(sys, prob.p, [x, q, s, r])
@test SciMLBase.get_saved_state_idxs(ss) == [xidx]
@test SciMLBase.get_saved_state_idxs(ss) == [variable_index(sys, x)]
sswf = SciMLBase.SavedSubsystemWithFallback(ss, sys)
xidx = variable_index(sys, x)
qidx = parameter_index(sys, q)
Expand Down
38 changes: 33 additions & 5 deletions test/remake_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -297,26 +297,54 @@ a = Remake_Test1(p = 1)
@test @inferred remake(a, kwargs = (; a = 1)) == Remake_Test1(p = 1, a = 1)

@testset "fill_u0 and fill_p ignore identical variables with different names" begin
sys = SymbolCache(Dict(:x => 1, :x2 => 1, :y => 2), Dict(:a => 1, :a2 => 1, :b => 2),
:t; defaults = Dict(:x => 1, :y => 2, :a => 3, :b => 4))
struct SCWrapper{S}
sc::S
end
SymbolicIndexingInterface.symbolic_container(s::SCWrapper) = s.sc
function SymbolicIndexingInterface.is_variable(s::SCWrapper, i::Symbol)
if i == :x2
return is_variable(s.sc, :x)
end
is_variable(s.sc, i)
end
function SymbolicIndexingInterface.variable_index(s::SCWrapper, i::Symbol)
if i == :x2
return variable_index(s.sc, :x)
end
variable_index(s.sc, i)
end
function SymbolicIndexingInterface.is_parameter(s::SCWrapper, i::Symbol)
if i == :a2
return is_parameter(s.sc, :a)
end
is_parameter(s.sc, i)
end
function SymbolicIndexingInterface.parameter_index(s::SCWrapper, i::Symbol)
if i == :a2
return parameter_index(s.sc, :a)
end
parameter_index(s.sc, i)
end
sys = SCWrapper(SymbolCache(Dict(:x => 1, :y => 2), Dict(:a => 1, :b => 2),
:t; defaults = Dict(:x => 1, :y => 2, :a => 3, :b => 4)))
function foo(du, u, p, t)
du .= u .* p
end
prob = ODEProblem(ODEFunction(foo; sys), [1.5, 2.5], (0.0, 1.0), [3.5, 4.5])
u0 = Dict(:x2 => 2)
newu0 = SciMLBase.fill_u0(prob, u0; defs = default_values(sys))
@test length(newu0) == 2
@test get(newu0, :x2, 0) == 2
@test get(newu0, :x, 0) == 2
@test get(newu0, :y, 0) == 2.5
p = Dict(:a2 => 3)
newp = SciMLBase.fill_p(prob, p; defs = default_values(sys))
@test length(newp) == 2
@test get(newp, :a2, 0) == 3
@test get(newp, :a, 0) == 3
@test get(newp, :b, 0) == 4.5
end

@testset "value of `nothing` is ignored" begin
sys = SymbolCache(Dict(:x => 1, :x2 => 1, :y => 2), Dict(:a => 1, :a2 => 1, :b => 2),
sys = SymbolCache(Dict(:x => 1, :y => 2), Dict(:a => 1, :b => 2),
:t; defaults = Dict(:x => 1, :y => 2, :a => 3, :b => 4))
function foo(du, u, p, t)
du .= u .* p
Expand Down

0 comments on commit babd239

Please sign in to comment.