Skip to content

Commit

Permalink
Merge pull request #832 from AayushSabharwal/as/remake-fix
Browse files Browse the repository at this point in the history
fix: fix `remake` for parameters dependent on observed variables
  • Loading branch information
ChrisRackauckas authored Oct 28, 2024
2 parents d3611ca + 3a25351 commit bf913e7
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 20 deletions.
39 changes: 19 additions & 20 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,7 @@ function varmap_get(varmap, var, default = nothing)
return default
end

anydict(d::Dict{Any, Any}) = d
anydict(d) = Dict{Any, Any}(d)
anydict() = Dict{Any, Any}()

Expand Down Expand Up @@ -658,8 +659,9 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false}, t0)
isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in u0)
isdep || return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), p

u0 = anydict(k => symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, u0)
for (k, v) in u0)
for (k, v) in u0
u0[k] = symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, u0)
end

isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in u0)
isdep || return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), p
Expand All @@ -668,17 +670,19 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false}, t0)
# This is sort of an implicit dependency on MTK. The values of `u` won't actually be
# used, since any state symbols in the expression were substituted out earlier.
temp_state = ProblemState(; u = state_values(prob), p = p, t = t0)
u0 = anydict(k => symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state)
for (k, v) in u0)
for (k, v) in u0
u0[k] = symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state)
end
return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), p
end

function _updated_u0_p_symmap(prob, u0, ::Val{false}, p, ::Val{true}, t0)
isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in p)
isdep || return u0, remake_buffer(prob, parameter_values(prob), keys(p), values(p))

p = anydict(k => symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, p)
for (k, v) in p)
for (k, v) in p
p[k] = symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, p)
end

isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in p)
isdep || return u0, remake_buffer(prob, parameter_values(prob), keys(p), values(p))
Expand All @@ -687,8 +691,9 @@ function _updated_u0_p_symmap(prob, u0, ::Val{false}, p, ::Val{true}, t0)
# this is sort of an implicit dependency on MTK. The values of `p` won't actually be
# used, since any parameter symbols in the expression were substituted out earlier.
temp_state = ProblemState(; u = u0, p = parameter_values(prob), t = t0)
p = anydict(k => symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state)
for (k, v) in p)
for (k, v) in p
p[k] = symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state)
end
return u0, remake_buffer(prob, parameter_values(prob), keys(p), values(p))
end

Expand All @@ -700,20 +705,14 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true}, t0)
return remake_buffer(prob, state_values(prob), keys(u0), values(u0)),
remake_buffer(prob, parameter_values(prob), keys(p), values(p))
end
if !isu0dep
u0 = remake_buffer(prob, state_values(prob), keys(u0), values(u0))
return _updated_u0_p_symmap(prob, u0, Val(false), p, Val(true), t0)
end
if !ispdep
p = remake_buffer(prob, parameter_values(prob), keys(p), values(p))
return _updated_u0_p_symmap(prob, u0, Val(true), p, Val(false), t0)
end

varmap = merge(u0, p)
u0 = anydict(k => symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, varmap)
for (k, v) in u0)
p = anydict(k => symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, varmap)
for (k, v) in p)
for (k, v) in u0
u0[k] = symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, varmap)
end
for (k, v) in p
p[k] = symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, varmap)
end
return remake_buffer(prob, state_values(prob), keys(u0), values(u0)),
remake_buffer(prob, parameter_values(prob), keys(p), values(p))
end
Expand Down
10 changes: 10 additions & 0 deletions test/downstream/modelingtoolkit_remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,13 @@ end
@test newoprob.ps[k] == [2.0, 3.0, 4.0, 5.0]
@test newoprob[V] == [1.5, 2.5]
end

@testset "remake with parameter dependent on observed" begin
@variables x(t) y(t)
@parameters p = x + y
@mtkbuild sys = ODESystem([D(x) ~ x, p ~ x + y], t)
prob = ODEProblem(sys, [x => 1.0, y => 2.0], (0.0, 1.0))
@test prob.ps[p] 3.0
prob2 = remake(prob; u0 = [y => 3.0], p = Dict())
@test prob2.ps[p] 4.0
end

0 comments on commit bf913e7

Please sign in to comment.