-
-
Notifications
You must be signed in to change notification settings - Fork 101
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: add implementations of CheckInit
and OverrideInit
#845
Changes from all commits
c9cb8f8
676403e
736b3d1
683f9d1
b3105cc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,8 +9,12 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap} | |
""" | ||
initializeprob::IProb | ||
""" | ||
A function which takes `(initializeprob, prob)` and updates | ||
A function which takes `(initializeprob, value_provider)` and updates | ||
the parameters of the former with their values in the latter. | ||
If absent (`nothing`) this will not be called, and the parameters | ||
in `initializeprob` will be used without modification. `value_provider` | ||
refers to a value provider as defined by SymbolicIndexingInterface.jl. | ||
Usually this will refer to a problem or integrator. | ||
""" | ||
update_initializeprob!::UIProb | ||
""" | ||
|
@@ -20,7 +24,9 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap} | |
initializeprobmap::IProbMap | ||
""" | ||
A function which takes the solution of `initializeprob` and returns | ||
the parameter object of the original problem. | ||
the parameter object of the original problem. If absent (`nothing`), | ||
this will not be called and the parameters of the problem being | ||
initialized will be returned as-is. | ||
""" | ||
initializeprobpmap::IProbPmap | ||
|
||
|
@@ -30,3 +36,155 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap} | |
return new{I, J, K, L}(initprob, update_initprob!, initprobmap, initprobpmap) | ||
end | ||
end | ||
|
||
""" | ||
get_initial_values(prob, valp, f, alg, isinplace; kwargs...) | ||
|
||
Return the initial `u0` and `p` for the given SciMLProblem and initialization algorithm, | ||
and a boolean indicating whether the initialization process was successful. Keyword | ||
arguments to this function are dependent on the initialization algorithm. `prob` is only | ||
required for dispatching. `valp` refers the appropriate data structure from which the | ||
current state and parameter values should be obtained. `valp` is a non-timeseries value | ||
provider as defined by SymbolicIndexingInterface.jl. `f` is the SciMLFunction for the | ||
problem. `alg` is the initialization algorithm to use. `isinplace` is either `Val{true}` | ||
if `valp` and the SciMLFunction are inplace, and `Val{false}` otherwise. | ||
""" | ||
function get_initial_values end | ||
|
||
struct CheckInitFailureError <: Exception | ||
normresid::Any | ||
abstol::Any | ||
end | ||
|
||
function Base.showerror(io::IO, e::CheckInitFailureError) | ||
print(io, | ||
"CheckInit specified but initialization not satisfied. normresid = $(e.normresid) > abstol = $(e.abstol)") | ||
end | ||
|
||
struct OverrideInitMissingAlgorithm <: Exception end | ||
|
||
function Base.showerror(io::IO, e::OverrideInitMissingAlgorithm) | ||
print(io, | ||
"OverrideInit specified but no NonlinearSolve.jl algorithm provided. Provide an algorithm via the `nlsolve_alg` keyword argument to `get_initial_values`.") | ||
end | ||
|
||
""" | ||
Utility function to evaluate the RHS of the ODE, using the integrator's `tmp_cache` if | ||
it is in-place or simply calling the function if not. | ||
""" | ||
function _evaluate_f_ode(integrator, f, isinplace::Val{true}, args...) | ||
tmp = first(get_tmp_cache(integrator)) | ||
f(tmp, args...) | ||
return tmp | ||
end | ||
|
||
function _evaluate_f_ode(integrator, f, isinplace::Val{false}, args...) | ||
return f(args...) | ||
end | ||
|
||
""" | ||
$(TYPEDSIGNATURES) | ||
|
||
A utility function equivalent to `Base.vec` but also handles `Number` and | ||
`AbstractSciMLScalarOperator`. | ||
""" | ||
_vec(v) = vec(v) | ||
_vec(v::Number) = v | ||
_vec(v::SciMLOperators.AbstractSciMLScalarOperator) = v | ||
_vec(v::AbstractVector) = v | ||
|
||
""" | ||
$(TYPEDSIGNATURES) | ||
|
||
Check if the algebraic constraints are satisfied, and error if they aren't. Returns | ||
the `u0` and `p` as-is, and is always successful if it returns. Valid only for | ||
`ODEProblem` and `DAEProblem`. Requires a `DEIntegrator` as its second argument. | ||
""" | ||
function get_initial_values(prob::ODEProblem, integrator, f, alg::CheckInit, | ||
isinplace::Union{Val{true}, Val{false}}; kwargs...) | ||
u0 = state_values(integrator) | ||
p = parameter_values(integrator) | ||
t = current_time(integrator) | ||
M = f.mass_matrix | ||
|
||
algebraic_vars = [all(iszero, x) for x in eachcol(M)] | ||
algebraic_eqs = [all(iszero, x) for x in eachrow(M)] | ||
(iszero(algebraic_vars) || iszero(algebraic_eqs)) && return | ||
update_coefficients!(M, u0, p, t) | ||
tmp = _evaluate_f_ode(integrator, f, isinplace, u0, p, t) | ||
tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This allocates There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Somehow changing this to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. At least, according to |
||
|
||
normresid = integrator.opts.internalnorm(tmp, t) | ||
if normresid > integrator.opts.abstol | ||
throw(CheckInitFailureError(normresid, integrator.opts.abstol)) | ||
end | ||
return u0, p, true | ||
end | ||
|
||
""" | ||
Utility function to evaluate the RHS of the DAE, using the integrator's `tmp_cache` if | ||
it is in-place or simply calling the function if not. | ||
""" | ||
function _evaluate_f_dae(integrator, f, isinplace::Val{true}, args...) | ||
tmp = get_tmp_cache(integrator)[2] | ||
f(tmp, args...) | ||
return tmp | ||
end | ||
|
||
function _evaluate_f_dae(integrator, f, isinplace::Val{false}, args...) | ||
return f(args...) | ||
end | ||
|
||
function get_initial_values(prob::DAEProblem, integrator, f, alg::CheckInit, | ||
isinplace::Union{Val{true}, Val{false}}; kwargs...) | ||
u0 = state_values(integrator) | ||
p = parameter_values(integrator) | ||
t = current_time(integrator) | ||
|
||
resid = _evaluate_f_dae(integrator, f, isinplace, integrator.du, u0, p, t) | ||
normresid = integrator.opts.internalnorm(resid, t) | ||
if normresid > integrator.opts.abstol | ||
throw(CheckInitFailureError(normresid, integrator.opts.abstol)) | ||
end | ||
return u0, p, true | ||
end | ||
|
||
""" | ||
$(TYPEDSIGNATURES) | ||
|
||
Solve a `NonlinearProblem`/`NonlinearLeastSquaresProblem` to obtain the initial `u0` and | ||
`p`. Requires that `f` have the field `initialization_data` which is an `OverrideInitData`. | ||
If the field is absent or the value is `nothing`, return `u0` and `p` successfully as-is. | ||
The NonlinearSolve.jl algorithm to use must be specified through the `nlsolve_alg` keyword | ||
argument, failing which this function will throw an error. The success value returned | ||
depends on the success of the nonlinear solve. | ||
""" | ||
function get_initial_values(prob, valp, f, alg::OverrideInit, | ||
isinplace::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, kwargs...) | ||
u0 = state_values(valp) | ||
p = parameter_values(valp) | ||
|
||
if !has_initialization_data(f) | ||
return u0, p, true | ||
end | ||
|
||
initdata::OverrideInitData = f.initialization_data | ||
initprob = initdata.initializeprob | ||
|
||
if nlsolve_alg === nothing | ||
throw(OverrideInitMissingAlgorithm()) | ||
end | ||
|
||
if initdata.update_initializeprob! !== nothing | ||
initdata.update_initializeprob!(initprob, valp) | ||
end | ||
|
||
nlsol = solve(initprob, nlsolve_alg) | ||
|
||
u0 = initdata.initializeprobmap(nlsol) | ||
if initdata.initializeprobpmap !== nothing | ||
p = initdata.initializeprobpmap(nlsol) | ||
end | ||
|
||
return u0, p, SciMLBase.successful_retcode(nlsol) | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
using OrdinaryDiffEq, NonlinearSolve, SymbolicIndexingInterface, Test | ||
|
||
@testset "CheckInit" begin | ||
@testset "ODEProblem" begin | ||
function rhs(u, p, t) | ||
return [u[1] * t, u[1]^2 - u[2]^2] | ||
end | ||
function rhs!(du, u, p, t) | ||
du[1] = u[1] * t | ||
du[2] = u[1]^2 - u[2]^2 | ||
end | ||
|
||
oopfn = ODEFunction{false}(rhs, mass_matrix = [1 0; 0 0]) | ||
iipfn = ODEFunction{true}(rhs!, mass_matrix = [1 0; 0 0]) | ||
|
||
@testset "Inplace = $(SciMLBase.isinplace(f))" for f in [oopfn, iipfn] | ||
prob = ODEProblem(f, [1.0, 1.0], (0.0, 1.0)) | ||
integ = init(prob) | ||
u0, _, success = SciMLBase.get_initial_values( | ||
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f))) | ||
@test success | ||
@test u0 == prob.u0 | ||
|
||
integ.u[2] = 2.0 | ||
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values( | ||
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f))) | ||
end | ||
end | ||
|
||
@testset "DAEProblem" begin | ||
function daerhs(du, u, p, t) | ||
return [du[1] - u[1] * t - p, u[1]^2 - u[2]^2] | ||
end | ||
function daerhs!(resid, du, u, p, t) | ||
resid[1] = du[1] - u[1] * t - p | ||
resid[2] = u[1]^2 - u[2]^2 | ||
end | ||
|
||
oopfn = DAEFunction{false}(daerhs) | ||
iipfn = DAEFunction{true}(daerhs!) | ||
|
||
@testset "Inplace = $(SciMLBase.isinplace(f))" for f in [oopfn, iipfn] | ||
prob = DAEProblem(f, [1.0, 0.0], [1.0, 1.0], (0.0, 1.0), 1.0) | ||
integ = init(prob, DImplicitEuler()) | ||
u0, _, success = SciMLBase.get_initial_values( | ||
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f))) | ||
@test success | ||
@test u0 == prob.u0 | ||
|
||
integ.u[2] = 2.0 | ||
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values( | ||
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f))) | ||
|
||
integ.u[2] = 1.0 | ||
integ.du[1] = 2.0 | ||
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values( | ||
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f))) | ||
end | ||
end | ||
end | ||
|
||
@testset "OverrideInit" begin | ||
function rhs2(u, p, t) | ||
return [u[1] * t + p, u[1]^2 - u[2]^2] | ||
end | ||
|
||
@testset "No-op without `initialization_data`" begin | ||
prob = ODEProblem(rhs2, [1.0, 2.0], (0.0, 1.0), 1.0) | ||
integ = init(prob) | ||
integ.u[2] = 3.0 | ||
u0, p, success = SciMLBase.get_initial_values( | ||
prob, integ, prob.f, SciMLBase.OverrideInit(), Val(false)) | ||
@test u0 ≈ [1.0, 3.0] | ||
@test success | ||
end | ||
|
||
# unknowns are u[2], p. Parameter is u[1] | ||
initprob = NonlinearProblem([1.0, 1.0], [1.0]) do x, _u1 | ||
u2, p = x | ||
u1 = _u1[1] | ||
return [u1^2 - u2^2, p^2 - 2p + 1] | ||
end | ||
update_initializeprob! = function (iprob, integ) | ||
iprob.p[1] = integ.u[1] | ||
end | ||
initprobmap = function (nlsol) | ||
return [parameter_values(nlsol)[1], nlsol.u[1]] | ||
end | ||
initprobpmap = function (nlsol) | ||
return nlsol.u[2] | ||
end | ||
initialization_data = SciMLBase.OverrideInitData( | ||
initprob, update_initializeprob!, initprobmap, initprobpmap) | ||
fn = ODEFunction(rhs2; initialization_data) | ||
prob = ODEProblem(fn, [2.0, 0.0], (0.0, 1.0), 0.0) | ||
integ = init(prob; initializealg = NoInit()) | ||
|
||
@testset "Errors without `nlsolve_alg`" begin | ||
@test_throws SciMLBase.OverrideInitMissingAlgorithm SciMLBase.get_initial_values( | ||
prob, integ, fn, SciMLBase.OverrideInit(), Val(false)) | ||
end | ||
|
||
@testset "Solves" begin | ||
u0, p, success = SciMLBase.get_initial_values( | ||
prob, integ, fn, SciMLBase.OverrideInit(), | ||
Val(false); nlsolve_alg = NewtonRaphson()) | ||
|
||
@test u0 ≈ [2.0, 2.0] | ||
@test p ≈ 1.0 | ||
@test success | ||
|
||
initprob.p[1] = 1.0 | ||
end | ||
|
||
@testset "Solves with non-integrator value provider" begin | ||
_integ = ProblemState(; u = integ.u, p = parameter_values(integ), t = integ.t) | ||
u0, p, success = SciMLBase.get_initial_values( | ||
prob, _integ, fn, SciMLBase.OverrideInit(), | ||
Val(false); nlsolve_alg = NewtonRaphson()) | ||
|
||
@test u0 ≈ [2.0, 2.0] | ||
@test p ≈ 1.0 | ||
@test success | ||
|
||
initprob.p[1] = 1.0 | ||
end | ||
|
||
@testset "Solves without `update_initializeprob!`" begin | ||
initdata = SciMLBase.@set initialization_data.update_initializeprob! = nothing | ||
fn = ODEFunction(rhs2; initialization_data = initdata) | ||
prob = ODEProblem(fn, [2.0, 0.0], (0.0, 1.0), 0.0) | ||
integ = init(prob; initializealg = NoInit()) | ||
|
||
u0, p, success = SciMLBase.get_initial_values( | ||
prob, integ, fn, SciMLBase.OverrideInit(), | ||
Val(false); nlsolve_alg = NewtonRaphson()) | ||
@test u0 ≈ [1.0, 1.0] | ||
@test p ≈ 1.0 | ||
@test success | ||
end | ||
|
||
@testset "Solves without `initializeprobpmap`" begin | ||
initdata = SciMLBase.@set initialization_data.initializeprobpmap = nothing | ||
fn = ODEFunction(rhs2; initialization_data = initdata) | ||
prob = ODEProblem(fn, [2.0, 0.0], (0.0, 1.0), 0.0) | ||
integ = init(prob; initializealg = NoInit()) | ||
|
||
u0, p, success = SciMLBase.get_initial_values( | ||
prob, integ, fn, SciMLBase.OverrideInit(), | ||
Val(false); nlsolve_alg = NewtonRaphson()) | ||
|
||
@test u0 ≈ [2.0, 2.0] | ||
@test p ≈ 0.0 | ||
@test success | ||
end | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should use a cached one from the integrator?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cached what, exactly? 😅 I copied almost all of this from the existing implementation