Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add implementations of CheckInit and OverrideInit #845

Merged
merged 5 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ LinearAlgebra = "1.10"
Logging = "1.10"
Makie = "0.20, 0.21"
Markdown = "1.10"
NonlinearSolve = "3, 4"
PartialFunctions = "1.1"
PrecompileTools = "1.2"
Preferences = "1.3"
Expand Down Expand Up @@ -98,6 +99,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Expand All @@ -114,4 +116,4 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Pkg", "Plots", "UnicodePlots", "SafeTestsets", "Serialization", "Test", "StableRNGs", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "OrdinaryDiffEq", "ForwardDiff", "Tables"]
test = ["Pkg", "Plots", "UnicodePlots", "SafeTestsets", "Serialization", "Test", "StableRNGs", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "NonlinearSolve", "OrdinaryDiffEq", "ForwardDiff", "Tables"]
7 changes: 6 additions & 1 deletion src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,11 @@ $(TYPEDEF)
"""
struct CheckInit <: DAEInitializationAlgorithm end

"""
$(TYPEDEF)
"""
struct OverrideInit <: DAEInitializationAlgorithm end

# PDE Discretizations

"""
Expand Down Expand Up @@ -654,7 +659,6 @@ Internal. Used for signifying the AD context comes from a Tracker.jl context.
struct TrackerOriginator <: ADOriginator end

include("utils.jl")
include("initialization.jl")
include("function_wrappers.jl")
include("scimlfunctions.jl")
include("alg_traits.jl")
Expand Down Expand Up @@ -740,6 +744,7 @@ include("ensemble/ensemble_problems.jl")
include("ensemble/basic_ensemble_solve.jl")
include("ensemble/ensemble_analysis.jl")

include("initialization.jl")
include("solve.jl")
include("interpolation.jl")
include("integrator_interface.jl")
Expand Down
162 changes: 160 additions & 2 deletions src/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap}
"""
initializeprob::IProb
"""
A function which takes `(initializeprob, prob)` and updates
A function which takes `(initializeprob, value_provider)` and updates
the parameters of the former with their values in the latter.
If absent (`nothing`) this will not be called, and the parameters
in `initializeprob` will be used without modification. `value_provider`
refers to a value provider as defined by SymbolicIndexingInterface.jl.
Usually this will refer to a problem or integrator.
"""
update_initializeprob!::UIProb
"""
Expand All @@ -20,7 +24,9 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap}
initializeprobmap::IProbMap
"""
A function which takes the solution of `initializeprob` and returns
the parameter object of the original problem.
the parameter object of the original problem. If absent (`nothing`),
this will not be called and the parameters of the problem being
initialized will be returned as-is.
"""
initializeprobpmap::IProbPmap

Expand All @@ -30,3 +36,155 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap}
return new{I, J, K, L}(initprob, update_initprob!, initprobmap, initprobpmap)
end
end

"""
get_initial_values(prob, valp, f, alg, isinplace; kwargs...)

Return the initial `u0` and `p` for the given SciMLProblem and initialization algorithm,
and a boolean indicating whether the initialization process was successful. Keyword
arguments to this function are dependent on the initialization algorithm. `prob` is only
required for dispatching. `valp` refers the appropriate data structure from which the
current state and parameter values should be obtained. `valp` is a non-timeseries value
provider as defined by SymbolicIndexingInterface.jl. `f` is the SciMLFunction for the
problem. `alg` is the initialization algorithm to use. `isinplace` is either `Val{true}`
if `valp` and the SciMLFunction are inplace, and `Val{false}` otherwise.
"""
function get_initial_values end

struct CheckInitFailureError <: Exception
normresid::Any
abstol::Any
end

function Base.showerror(io::IO, e::CheckInitFailureError)
print(io,
"CheckInit specified but initialization not satisfied. normresid = $(e.normresid) > abstol = $(e.abstol)")
end

struct OverrideInitMissingAlgorithm <: Exception end

function Base.showerror(io::IO, e::OverrideInitMissingAlgorithm)
print(io,
"OverrideInit specified but no NonlinearSolve.jl algorithm provided. Provide an algorithm via the `nlsolve_alg` keyword argument to `get_initial_values`.")
end

"""
Utility function to evaluate the RHS of the ODE, using the integrator's `tmp_cache` if
it is in-place or simply calling the function if not.
"""
function _evaluate_f_ode(integrator, f, isinplace::Val{true}, args...)
tmp = first(get_tmp_cache(integrator))
f(tmp, args...)
return tmp
end

function _evaluate_f_ode(integrator, f, isinplace::Val{false}, args...)
return f(args...)
end

"""
$(TYPEDSIGNATURES)

A utility function equivalent to `Base.vec` but also handles `Number` and
`AbstractSciMLScalarOperator`.
"""
_vec(v) = vec(v)
_vec(v::Number) = v
_vec(v::SciMLOperators.AbstractSciMLScalarOperator) = v
_vec(v::AbstractVector) = v

"""
$(TYPEDSIGNATURES)

Check if the algebraic constraints are satisfied, and error if they aren't. Returns
the `u0` and `p` as-is, and is always successful if it returns. Valid only for
`ODEProblem` and `DAEProblem`. Requires a `DEIntegrator` as its second argument.
"""
function get_initial_values(prob::ODEProblem, integrator, f, alg::CheckInit,
isinplace::Union{Val{true}, Val{false}}; kwargs...)
u0 = state_values(integrator)
p = parameter_values(integrator)
t = current_time(integrator)
M = f.mass_matrix

algebraic_vars = [all(iszero, x) for x in eachcol(M)]
algebraic_eqs = [all(iszero, x) for x in eachrow(M)]
Comment on lines +110 to +111
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should use a cached one from the integrator?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cached what, exactly? 😅 I copied almost all of this from the existing implementation

(iszero(algebraic_vars) || iszero(algebraic_eqs)) && return
update_coefficients!(M, u0, p, t)
tmp = _evaluate_f_ode(integrator, f, isinplace, u0, p, t)
tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This allocates

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Somehow changing this to tmp = algebraic_eqs .* _vec(tmp) allocates more

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At least, according to @allocated/@allocations


normresid = integrator.opts.internalnorm(tmp, t)
if normresid > integrator.opts.abstol
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
end
return u0, p, true
end

"""
Utility function to evaluate the RHS of the DAE, using the integrator's `tmp_cache` if
it is in-place or simply calling the function if not.
"""
function _evaluate_f_dae(integrator, f, isinplace::Val{true}, args...)
tmp = get_tmp_cache(integrator)[2]
f(tmp, args...)
return tmp
end

function _evaluate_f_dae(integrator, f, isinplace::Val{false}, args...)
return f(args...)
end

function get_initial_values(prob::DAEProblem, integrator, f, alg::CheckInit,
isinplace::Union{Val{true}, Val{false}}; kwargs...)
u0 = state_values(integrator)
p = parameter_values(integrator)
t = current_time(integrator)

resid = _evaluate_f_dae(integrator, f, isinplace, integrator.du, u0, p, t)
normresid = integrator.opts.internalnorm(resid, t)
if normresid > integrator.opts.abstol
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
end
return u0, p, true
end

"""
$(TYPEDSIGNATURES)

Solve a `NonlinearProblem`/`NonlinearLeastSquaresProblem` to obtain the initial `u0` and
`p`. Requires that `f` have the field `initialization_data` which is an `OverrideInitData`.
If the field is absent or the value is `nothing`, return `u0` and `p` successfully as-is.
The NonlinearSolve.jl algorithm to use must be specified through the `nlsolve_alg` keyword
argument, failing which this function will throw an error. The success value returned
depends on the success of the nonlinear solve.
"""
function get_initial_values(prob, valp, f, alg::OverrideInit,
isinplace::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, kwargs...)
u0 = state_values(valp)
p = parameter_values(valp)

if !has_initialization_data(f)
return u0, p, true
end

initdata::OverrideInitData = f.initialization_data
initprob = initdata.initializeprob

if nlsolve_alg === nothing
throw(OverrideInitMissingAlgorithm())
end

if initdata.update_initializeprob! !== nothing
initdata.update_initializeprob!(initprob, valp)
end

nlsol = solve(initprob, nlsolve_alg)

u0 = initdata.initializeprobmap(nlsol)
if initdata.initializeprobpmap !== nothing
p = initdata.initializeprobpmap(nlsol)
end

return u0, p, SciMLBase.successful_retcode(nlsol)
end
3 changes: 2 additions & 1 deletion src/solutions/save_idxs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,8 @@ function get_save_idxs_and_saved_subsystem(prob, save_idxs)
if isempty(_save_idxs)
# no states to save
save_idxs = Int[]
elseif !(save_idxs isa AbstractArray) || symbolic_type(save_idxs) != NotSymbolic()
elseif !(save_idxs isa AbstractArray) ||
symbolic_type(save_idxs) != NotSymbolic()
# only a single state to save, and save it as a scalar timeseries instead of
# single-element array
save_idxs = only(_save_idxs)
Expand Down
156 changes: 156 additions & 0 deletions test/initialization.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
using OrdinaryDiffEq, NonlinearSolve, SymbolicIndexingInterface, Test

@testset "CheckInit" begin
@testset "ODEProblem" begin
function rhs(u, p, t)
return [u[1] * t, u[1]^2 - u[2]^2]
end
function rhs!(du, u, p, t)
du[1] = u[1] * t
du[2] = u[1]^2 - u[2]^2
end

oopfn = ODEFunction{false}(rhs, mass_matrix = [1 0; 0 0])
iipfn = ODEFunction{true}(rhs!, mass_matrix = [1 0; 0 0])

@testset "Inplace = $(SciMLBase.isinplace(f))" for f in [oopfn, iipfn]
prob = ODEProblem(f, [1.0, 1.0], (0.0, 1.0))
integ = init(prob)
u0, _, success = SciMLBase.get_initial_values(
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
@test success
@test u0 == prob.u0

integ.u[2] = 2.0
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values(
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
end
end

@testset "DAEProblem" begin
function daerhs(du, u, p, t)
return [du[1] - u[1] * t - p, u[1]^2 - u[2]^2]
end
function daerhs!(resid, du, u, p, t)
resid[1] = du[1] - u[1] * t - p
resid[2] = u[1]^2 - u[2]^2
end

oopfn = DAEFunction{false}(daerhs)
iipfn = DAEFunction{true}(daerhs!)

@testset "Inplace = $(SciMLBase.isinplace(f))" for f in [oopfn, iipfn]
prob = DAEProblem(f, [1.0, 0.0], [1.0, 1.0], (0.0, 1.0), 1.0)
integ = init(prob, DImplicitEuler())
u0, _, success = SciMLBase.get_initial_values(
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
@test success
@test u0 == prob.u0

integ.u[2] = 2.0
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values(
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))

integ.u[2] = 1.0
integ.du[1] = 2.0
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values(
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
end
end
end

@testset "OverrideInit" begin
function rhs2(u, p, t)
return [u[1] * t + p, u[1]^2 - u[2]^2]
end

@testset "No-op without `initialization_data`" begin
prob = ODEProblem(rhs2, [1.0, 2.0], (0.0, 1.0), 1.0)
integ = init(prob)
integ.u[2] = 3.0
u0, p, success = SciMLBase.get_initial_values(
prob, integ, prob.f, SciMLBase.OverrideInit(), Val(false))
@test u0 ≈ [1.0, 3.0]
@test success
end

# unknowns are u[2], p. Parameter is u[1]
initprob = NonlinearProblem([1.0, 1.0], [1.0]) do x, _u1
u2, p = x
u1 = _u1[1]
return [u1^2 - u2^2, p^2 - 2p + 1]
end
update_initializeprob! = function (iprob, integ)
iprob.p[1] = integ.u[1]
end
initprobmap = function (nlsol)
return [parameter_values(nlsol)[1], nlsol.u[1]]
end
initprobpmap = function (nlsol)
return nlsol.u[2]
end
initialization_data = SciMLBase.OverrideInitData(
initprob, update_initializeprob!, initprobmap, initprobpmap)
fn = ODEFunction(rhs2; initialization_data)
prob = ODEProblem(fn, [2.0, 0.0], (0.0, 1.0), 0.0)
integ = init(prob; initializealg = NoInit())

@testset "Errors without `nlsolve_alg`" begin
@test_throws SciMLBase.OverrideInitMissingAlgorithm SciMLBase.get_initial_values(
prob, integ, fn, SciMLBase.OverrideInit(), Val(false))
end

@testset "Solves" begin
u0, p, success = SciMLBase.get_initial_values(
prob, integ, fn, SciMLBase.OverrideInit(),
Val(false); nlsolve_alg = NewtonRaphson())

@test u0 ≈ [2.0, 2.0]
@test p ≈ 1.0
@test success

initprob.p[1] = 1.0
end

@testset "Solves with non-integrator value provider" begin
_integ = ProblemState(; u = integ.u, p = parameter_values(integ), t = integ.t)
u0, p, success = SciMLBase.get_initial_values(
prob, _integ, fn, SciMLBase.OverrideInit(),
Val(false); nlsolve_alg = NewtonRaphson())

@test u0 ≈ [2.0, 2.0]
@test p ≈ 1.0
@test success

initprob.p[1] = 1.0
end

@testset "Solves without `update_initializeprob!`" begin
initdata = SciMLBase.@set initialization_data.update_initializeprob! = nothing
fn = ODEFunction(rhs2; initialization_data = initdata)
prob = ODEProblem(fn, [2.0, 0.0], (0.0, 1.0), 0.0)
integ = init(prob; initializealg = NoInit())

u0, p, success = SciMLBase.get_initial_values(
prob, integ, fn, SciMLBase.OverrideInit(),
Val(false); nlsolve_alg = NewtonRaphson())
@test u0 ≈ [1.0, 1.0]
@test p ≈ 1.0
@test success
end

@testset "Solves without `initializeprobpmap`" begin
initdata = SciMLBase.@set initialization_data.initializeprobpmap = nothing
fn = ODEFunction(rhs2; initialization_data = initdata)
prob = ODEProblem(fn, [2.0, 0.0], (0.0, 1.0), 0.0)
integ = init(prob; initializealg = NoInit())

u0, p, success = SciMLBase.get_initial_values(
prob, integ, fn, SciMLBase.OverrideInit(),
Val(false); nlsolve_alg = NewtonRaphson())

@test u0 ≈ [2.0, 2.0]
@test p ≈ 0.0
@test success
end
end
Loading
Loading