Skip to content

Commit

Permalink
Merge pull request #845 from AayushSabharwal/as/move-initalgs
Browse files Browse the repository at this point in the history
feat: add implementations of `CheckInit` and `OverrideInit`
  • Loading branch information
ChrisRackauckas authored Nov 7, 2024
2 parents 2c4dfc0 + b3105cc commit cf142a8
Show file tree
Hide file tree
Showing 6 changed files with 330 additions and 5 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ LinearAlgebra = "1.10"
Logging = "1.10"
Makie = "0.20, 0.21"
Markdown = "1.10"
NonlinearSolve = "3, 4"
PartialFunctions = "1.1"
PrecompileTools = "1.2"
Preferences = "1.3"
Expand Down Expand Up @@ -98,6 +99,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Expand All @@ -114,4 +116,4 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Pkg", "Plots", "UnicodePlots", "SafeTestsets", "Serialization", "Test", "StableRNGs", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "OrdinaryDiffEq", "ForwardDiff", "Tables"]
test = ["Pkg", "Plots", "UnicodePlots", "SafeTestsets", "Serialization", "Test", "StableRNGs", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "NonlinearSolve", "OrdinaryDiffEq", "ForwardDiff", "Tables"]
7 changes: 6 additions & 1 deletion src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,11 @@ $(TYPEDEF)
"""
struct CheckInit <: DAEInitializationAlgorithm end

"""
$(TYPEDEF)
"""
struct OverrideInit <: DAEInitializationAlgorithm end

# PDE Discretizations

"""
Expand Down Expand Up @@ -654,7 +659,6 @@ Internal. Used for signifying the AD context comes from a Tracker.jl context.
struct TrackerOriginator <: ADOriginator end

include("utils.jl")
include("initialization.jl")
include("function_wrappers.jl")
include("scimlfunctions.jl")
include("alg_traits.jl")
Expand Down Expand Up @@ -740,6 +744,7 @@ include("ensemble/ensemble_problems.jl")
include("ensemble/basic_ensemble_solve.jl")
include("ensemble/ensemble_analysis.jl")

include("initialization.jl")
include("solve.jl")
include("interpolation.jl")
include("integrator_interface.jl")
Expand Down
162 changes: 160 additions & 2 deletions src/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap}
"""
initializeprob::IProb
"""
A function which takes `(initializeprob, prob)` and updates
A function which takes `(initializeprob, value_provider)` and updates
the parameters of the former with their values in the latter.
If absent (`nothing`) this will not be called, and the parameters
in `initializeprob` will be used without modification. `value_provider`
refers to a value provider as defined by SymbolicIndexingInterface.jl.
Usually this will refer to a problem or integrator.
"""
update_initializeprob!::UIProb
"""
Expand All @@ -20,7 +24,9 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap}
initializeprobmap::IProbMap
"""
A function which takes the solution of `initializeprob` and returns
the parameter object of the original problem.
the parameter object of the original problem. If absent (`nothing`),
this will not be called and the parameters of the problem being
initialized will be returned as-is.
"""
initializeprobpmap::IProbPmap

Expand All @@ -30,3 +36,155 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap}
return new{I, J, K, L}(initprob, update_initprob!, initprobmap, initprobpmap)
end
end

"""
get_initial_values(prob, valp, f, alg, isinplace; kwargs...)
Return the initial `u0` and `p` for the given SciMLProblem and initialization algorithm,
and a boolean indicating whether the initialization process was successful. Keyword
arguments to this function are dependent on the initialization algorithm. `prob` is only
required for dispatching. `valp` refers the appropriate data structure from which the
current state and parameter values should be obtained. `valp` is a non-timeseries value
provider as defined by SymbolicIndexingInterface.jl. `f` is the SciMLFunction for the
problem. `alg` is the initialization algorithm to use. `isinplace` is either `Val{true}`
if `valp` and the SciMLFunction are inplace, and `Val{false}` otherwise.
"""
function get_initial_values end

struct CheckInitFailureError <: Exception
normresid::Any
abstol::Any
end

function Base.showerror(io::IO, e::CheckInitFailureError)
print(io,
"CheckInit specified but initialization not satisfied. normresid = $(e.normresid) > abstol = $(e.abstol)")
end

struct OverrideInitMissingAlgorithm <: Exception end

function Base.showerror(io::IO, e::OverrideInitMissingAlgorithm)
print(io,
"OverrideInit specified but no NonlinearSolve.jl algorithm provided. Provide an algorithm via the `nlsolve_alg` keyword argument to `get_initial_values`.")
end

"""
Utility function to evaluate the RHS of the ODE, using the integrator's `tmp_cache` if
it is in-place or simply calling the function if not.
"""
function _evaluate_f_ode(integrator, f, isinplace::Val{true}, args...)
tmp = first(get_tmp_cache(integrator))
f(tmp, args...)
return tmp
end

function _evaluate_f_ode(integrator, f, isinplace::Val{false}, args...)
return f(args...)
end

"""
$(TYPEDSIGNATURES)
A utility function equivalent to `Base.vec` but also handles `Number` and
`AbstractSciMLScalarOperator`.
"""
_vec(v) = vec(v)
_vec(v::Number) = v
_vec(v::SciMLOperators.AbstractSciMLScalarOperator) = v
_vec(v::AbstractVector) = v

"""
$(TYPEDSIGNATURES)
Check if the algebraic constraints are satisfied, and error if they aren't. Returns
the `u0` and `p` as-is, and is always successful if it returns. Valid only for
`ODEProblem` and `DAEProblem`. Requires a `DEIntegrator` as its second argument.
"""
function get_initial_values(prob::ODEProblem, integrator, f, alg::CheckInit,
isinplace::Union{Val{true}, Val{false}}; kwargs...)
u0 = state_values(integrator)
p = parameter_values(integrator)
t = current_time(integrator)
M = f.mass_matrix

algebraic_vars = [all(iszero, x) for x in eachcol(M)]
algebraic_eqs = [all(iszero, x) for x in eachrow(M)]
(iszero(algebraic_vars) || iszero(algebraic_eqs)) && return
update_coefficients!(M, u0, p, t)
tmp = _evaluate_f_ode(integrator, f, isinplace, u0, p, t)
tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp))

normresid = integrator.opts.internalnorm(tmp, t)
if normresid > integrator.opts.abstol
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
end
return u0, p, true
end

"""
Utility function to evaluate the RHS of the DAE, using the integrator's `tmp_cache` if
it is in-place or simply calling the function if not.
"""
function _evaluate_f_dae(integrator, f, isinplace::Val{true}, args...)
tmp = get_tmp_cache(integrator)[2]
f(tmp, args...)
return tmp
end

function _evaluate_f_dae(integrator, f, isinplace::Val{false}, args...)
return f(args...)
end

function get_initial_values(prob::DAEProblem, integrator, f, alg::CheckInit,
isinplace::Union{Val{true}, Val{false}}; kwargs...)
u0 = state_values(integrator)
p = parameter_values(integrator)
t = current_time(integrator)

resid = _evaluate_f_dae(integrator, f, isinplace, integrator.du, u0, p, t)
normresid = integrator.opts.internalnorm(resid, t)
if normresid > integrator.opts.abstol
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
end
return u0, p, true
end

"""
$(TYPEDSIGNATURES)
Solve a `NonlinearProblem`/`NonlinearLeastSquaresProblem` to obtain the initial `u0` and
`p`. Requires that `f` have the field `initialization_data` which is an `OverrideInitData`.
If the field is absent or the value is `nothing`, return `u0` and `p` successfully as-is.
The NonlinearSolve.jl algorithm to use must be specified through the `nlsolve_alg` keyword
argument, failing which this function will throw an error. The success value returned
depends on the success of the nonlinear solve.
"""
function get_initial_values(prob, valp, f, alg::OverrideInit,
isinplace::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, kwargs...)
u0 = state_values(valp)
p = parameter_values(valp)

if !has_initialization_data(f)
return u0, p, true
end

initdata::OverrideInitData = f.initialization_data
initprob = initdata.initializeprob

if nlsolve_alg === nothing
throw(OverrideInitMissingAlgorithm())
end

if initdata.update_initializeprob! !== nothing
initdata.update_initializeprob!(initprob, valp)
end

nlsol = solve(initprob, nlsolve_alg)

u0 = initdata.initializeprobmap(nlsol)
if initdata.initializeprobpmap !== nothing
p = initdata.initializeprobpmap(nlsol)
end

return u0, p, SciMLBase.successful_retcode(nlsol)
end
3 changes: 2 additions & 1 deletion src/solutions/save_idxs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,8 @@ function get_save_idxs_and_saved_subsystem(prob, save_idxs)
if isempty(_save_idxs)
# no states to save
save_idxs = Int[]
elseif !(save_idxs isa AbstractArray) || symbolic_type(save_idxs) != NotSymbolic()
elseif !(save_idxs isa AbstractArray) ||
symbolic_type(save_idxs) != NotSymbolic()
# only a single state to save, and save it as a scalar timeseries instead of
# single-element array
save_idxs = only(_save_idxs)
Expand Down
156 changes: 156 additions & 0 deletions test/initialization.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
using OrdinaryDiffEq, NonlinearSolve, SymbolicIndexingInterface, Test

@testset "CheckInit" begin
@testset "ODEProblem" begin
function rhs(u, p, t)
return [u[1] * t, u[1]^2 - u[2]^2]
end
function rhs!(du, u, p, t)
du[1] = u[1] * t
du[2] = u[1]^2 - u[2]^2
end

oopfn = ODEFunction{false}(rhs, mass_matrix = [1 0; 0 0])
iipfn = ODEFunction{true}(rhs!, mass_matrix = [1 0; 0 0])

@testset "Inplace = $(SciMLBase.isinplace(f))" for f in [oopfn, iipfn]
prob = ODEProblem(f, [1.0, 1.0], (0.0, 1.0))
integ = init(prob)
u0, _, success = SciMLBase.get_initial_values(
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
@test success
@test u0 == prob.u0

integ.u[2] = 2.0
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values(
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
end
end

@testset "DAEProblem" begin
function daerhs(du, u, p, t)
return [du[1] - u[1] * t - p, u[1]^2 - u[2]^2]
end
function daerhs!(resid, du, u, p, t)
resid[1] = du[1] - u[1] * t - p
resid[2] = u[1]^2 - u[2]^2
end

oopfn = DAEFunction{false}(daerhs)
iipfn = DAEFunction{true}(daerhs!)

@testset "Inplace = $(SciMLBase.isinplace(f))" for f in [oopfn, iipfn]
prob = DAEProblem(f, [1.0, 0.0], [1.0, 1.0], (0.0, 1.0), 1.0)
integ = init(prob, DImplicitEuler())
u0, _, success = SciMLBase.get_initial_values(
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
@test success
@test u0 == prob.u0

integ.u[2] = 2.0
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values(
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))

integ.u[2] = 1.0
integ.du[1] = 2.0
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values(
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
end
end
end

@testset "OverrideInit" begin
function rhs2(u, p, t)
return [u[1] * t + p, u[1]^2 - u[2]^2]
end

@testset "No-op without `initialization_data`" begin
prob = ODEProblem(rhs2, [1.0, 2.0], (0.0, 1.0), 1.0)
integ = init(prob)
integ.u[2] = 3.0
u0, p, success = SciMLBase.get_initial_values(
prob, integ, prob.f, SciMLBase.OverrideInit(), Val(false))
@test u0 [1.0, 3.0]
@test success
end

# unknowns are u[2], p. Parameter is u[1]
initprob = NonlinearProblem([1.0, 1.0], [1.0]) do x, _u1
u2, p = x
u1 = _u1[1]
return [u1^2 - u2^2, p^2 - 2p + 1]
end
update_initializeprob! = function (iprob, integ)
iprob.p[1] = integ.u[1]
end
initprobmap = function (nlsol)
return [parameter_values(nlsol)[1], nlsol.u[1]]
end
initprobpmap = function (nlsol)
return nlsol.u[2]
end
initialization_data = SciMLBase.OverrideInitData(
initprob, update_initializeprob!, initprobmap, initprobpmap)
fn = ODEFunction(rhs2; initialization_data)
prob = ODEProblem(fn, [2.0, 0.0], (0.0, 1.0), 0.0)
integ = init(prob; initializealg = NoInit())

@testset "Errors without `nlsolve_alg`" begin
@test_throws SciMLBase.OverrideInitMissingAlgorithm SciMLBase.get_initial_values(
prob, integ, fn, SciMLBase.OverrideInit(), Val(false))
end

@testset "Solves" begin
u0, p, success = SciMLBase.get_initial_values(
prob, integ, fn, SciMLBase.OverrideInit(),
Val(false); nlsolve_alg = NewtonRaphson())

@test u0 [2.0, 2.0]
@test p 1.0
@test success

initprob.p[1] = 1.0
end

@testset "Solves with non-integrator value provider" begin
_integ = ProblemState(; u = integ.u, p = parameter_values(integ), t = integ.t)
u0, p, success = SciMLBase.get_initial_values(
prob, _integ, fn, SciMLBase.OverrideInit(),
Val(false); nlsolve_alg = NewtonRaphson())

@test u0 [2.0, 2.0]
@test p 1.0
@test success

initprob.p[1] = 1.0
end

@testset "Solves without `update_initializeprob!`" begin
initdata = SciMLBase.@set initialization_data.update_initializeprob! = nothing
fn = ODEFunction(rhs2; initialization_data = initdata)
prob = ODEProblem(fn, [2.0, 0.0], (0.0, 1.0), 0.0)
integ = init(prob; initializealg = NoInit())

u0, p, success = SciMLBase.get_initial_values(
prob, integ, fn, SciMLBase.OverrideInit(),
Val(false); nlsolve_alg = NewtonRaphson())
@test u0 [1.0, 1.0]
@test p 1.0
@test success
end

@testset "Solves without `initializeprobpmap`" begin
initdata = SciMLBase.@set initialization_data.initializeprobpmap = nothing
fn = ODEFunction(rhs2; initialization_data = initdata)
prob = ODEProblem(fn, [2.0, 0.0], (0.0, 1.0), 0.0)
integ = init(prob; initializealg = NoInit())

u0, p, success = SciMLBase.get_initial_values(
prob, integ, fn, SciMLBase.OverrideInit(),
Val(false); nlsolve_alg = NewtonRaphson())

@test u0 [2.0, 2.0]
@test p 0.0
@test success
end
end
Loading

0 comments on commit cf142a8

Please sign in to comment.