Skip to content

Commit

Permalink
Merge pull request #477 from avik-pal/ap/bvp
Browse files Browse the repository at this point in the history
Revisiting Boundary Value Problems
  • Loading branch information
ChrisRackauckas authored Sep 21, 2023
2 parents eb92995 + 593e7f0 commit 434e752
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 125 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SciMLBase"
uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
authors = ["Chris Rackauckas <[email protected]> and contributors"]
version = "1.98.1"
version = "2.0.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
113 changes: 64 additions & 49 deletions src/problems/bvp_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@ $(TYPEDEF)
"""
struct StandardBVProblem end

"""
$(TYPEDEF)
"""
struct TwoPointBVProblem end

@doc doc"""
Defines an BVP problem.
Expand All @@ -17,7 +22,7 @@ condition ``u_0`` which define an ODE:
\frac{du}{dt} = f(u,p,t)
```
along with an implicit function `bc!` which defines the residual equation, where
along with an implicit function `bc` which defines the residual equation, where
```math
bc(u,p,t) = 0
Expand All @@ -36,22 +41,27 @@ u(t_f) = b
### Constructors
```julia
TwoPointBVProblem{isinplace}(f,bc!,u0,tspan,p=NullParameters();kwargs...)
BVProblem{isinplace}(f,bc!,u0,tspan,p=NullParameters();kwargs...)
TwoPointBVProblem{isinplace}(f,bc,u0,tspan,p=NullParameters();kwargs...)
BVProblem{isinplace}(f,bc,u0,tspan,p=NullParameters();kwargs...)
```
or if we have an initial guess function `initialGuess(t)` for the given BVP,
we can pass the initial guess to the problem constructors:
```julia
TwoPointBVProblem{isinplace}(f,bc!,initialGuess,tspan,p=NullParameters();kwargs...)
BVProblem{isinplace}(f,bc!,initialGuess,tspan,p=NullParameters();kwargs...)
TwoPointBVProblem{isinplace}(f,bc,initialGuess,tspan,p=NullParameters();kwargs...)
BVProblem{isinplace}(f,bc,initialGuess,tspan,p=NullParameters();kwargs...)
```
For any BVP problem type, `bc!` is the inplace function:
For any BVP problem type, `bc` must be inplace if `f` is inplace. Otherwise it must be
out-of-place.
If the bvp is a StandardBVProblem (also known as a Multi-Point BV Problem) it must define
either of the following functions
```julia
bc!(residual, u, p, t)
residual = bc(u, p, t)
```
where `residual` computed from the current `u`. `u` is an array of solution values
Expand All @@ -61,6 +71,16 @@ time points, and for shooting type methods `u=sol` the ODE solution.
Note that all features of the `ODESolution` are present in this form.
In both cases, the size of the residual matches the size of the initial condition.
If the bvp is a TwoPointBVProblem it must define either of the following functions
```julia
bc!((resid_a, resid_b), (u_a, u_b), p)
resid_a, resid_b = bc((u_a, u_b), p)
```
where `resid_a` and `resid_b` are the residuals at the two endpoints, `u_a` and `u_b` are
the solution values at the two endpoints, and `p` are the parameters.
Parameters are optional, and if not given, then a `NullParameters()` singleton
will be used which will throw nice errors if you try to index non-existent
parameters. Any extra keyword arguments are passed on to the solvers. For example,
Expand Down Expand Up @@ -88,16 +108,20 @@ struct BVProblem{uType, tType, isinplace, P, F, BF, PT, K} <:
problem_type::PT
kwargs::K

@add_kwonly function BVProblem{iip}(f::AbstractBVPFunction{iip}, bc, u0, tspan,
p = NullParameters(),
problem_type = StandardBVProblem();
kwargs...) where {iip}
@add_kwonly function BVProblem{iip}(f::AbstractBVPFunction{iip, TP}, bc, u0, tspan,
p = NullParameters(); problem_type=nothing, kwargs...) where {iip, TP}
_tspan = promote_tspan(tspan)
warn_paramtype(p)
new{typeof(u0), typeof(_tspan), isinplace(f), typeof(p),
typeof(f), typeof(f.bc),
typeof(problem_type), typeof(kwargs)}(f, f.bc, u0, _tspan, p,
problem_type, kwargs)
prob_type = TP ? TwoPointBVProblem() : StandardBVProblem()
# Needed to ensure that `problem_type` doesn't get passed in kwargs
if problem_type === nothing
problem_type = prob_type
else
@assert prob_type === problem_type "This indicates incorrect problem type specification! Users should never pass in `problem_type` kwarg, this exists exclusively for internal use."
end
return new{typeof(u0), typeof(_tspan), iip, typeof(p), typeof(f), typeof(bc),
typeof(problem_type), typeof(kwargs)}(f, bc, u0, _tspan, p, problem_type,
kwargs)
end

function BVProblem{iip}(f, bc, u0, tspan, p = NullParameters(); kwargs...) where {iip}
Expand All @@ -107,52 +131,43 @@ end

TruncatedStacktraces.@truncate_stacktrace BVProblem 3 1 2

function BVProblem(f::AbstractBVPFunction, u0, tspan, p = NullParameters(); kwargs...)
BVProblem{isinplace(f)}(f, f.bc, u0, tspan, p; kwargs...)
function BVProblem(f, bc, u0, tspan, p = NullParameters(); kwargs...)
iip = isinplace(f, 4)
return BVProblem{iip}(BVPFunction{iip}(f, bc), bc, u0, tspan, p; kwargs...)
end

function BVProblem(f, bc, u0, tspan, p = NullParameters(); kwargs...)
BVProblem(BVPFunction(f, bc), u0, tspan, p; kwargs...)
function BVProblem(f::AbstractBVPFunction, u0, tspan, p = NullParameters(); kwargs...)
return BVProblem{isinplace(f)}(f, f.bc, u0, tspan, p; kwargs...)
end

"""
$(TYPEDEF)
"""
struct TwoPointBVPFunction{bF}
bc::bF
# This is mostly a fake stuct and isn't used anywhere
# But we need it for function calls like TwoPointBVProblem{iip}(...) = ...
struct TwoPointBVPFunction{iip} end

@inline TwoPointBVPFunction(args...; kwargs...) = BVPFunction(args...; kwargs..., twopoint=true)
@inline function TwoPointBVPFunction{iip}(args...; kwargs...) where {iip}
return BVPFunction{iip}(args...; kwargs..., twopoint=true)
end
TwoPointBVPFunction(; bc = error("No argument bc")) = TwoPointBVPFunction(bc)
(f::TwoPointBVPFunction)(residual, ua, ub, p) = f.bc(residual, ua, ub, p)
(f::TwoPointBVPFunction)(residual, u, p) = f.bc(residual, u[1], u[end], p)

"""
$(TYPEDEF)
"""
struct TwoPointBVProblem{iip} end
function TwoPointBVProblem(f, bc, u0, tspan, p = NullParameters(); kwargs...)
iip = isinplace(f, 4)
TwoPointBVProblem{iip}(f, bc, u0, tspan, p; kwargs...)
function TwoPointBVProblem(f, bc, u0, tspan, p = NullParameters();
bcresid_prototype=nothing, kwargs...)
return TwoPointBVProblem(TwoPointBVPFunction(f, bc; bcresid_prototype), u0, tspan, p;
kwargs...)
end
function TwoPointBVProblem{iip}(f, bc, u0, tspan, p = NullParameters();
kwargs...) where {iip}
BVProblem{iip}(f, TwoPointBVPFunction(bc), u0, tspan, p; kwargs...)
function TwoPointBVProblem(f::AbstractBVPFunction{iip, twopoint}, u0, tspan,
p = NullParameters(); kwargs...) where {iip, twopoint}
@assert twopoint "`TwoPointBVProblem` can only be used with a `TwoPointBVPFunction`. Instead of using `BVPFunction`, use `TwoPointBVPFunction` or pass a kwarg `twopoint=true` during the construction of the `BVPFunction`."
return BVProblem{iip}(f, f.bc, u0, tspan, p; kwargs...)
end

# Allow previous timeseries solution
function TwoPointBVProblem(f::AbstractODEFunction,
bc,
sol::T,
tspan::Tuple,
p = NullParameters()) where {T <: AbstractTimeseriesSolution}
TwoPointBVProblem(f, bc, sol.u, tspan, p)
function TwoPointBVProblem(f::AbstractODEFunction, bc, sol::T, tspan::Tuple,
p = NullParameters(); kwargs...) where {T <: AbstractTimeseriesSolution}
return TwoPointBVProblem(f, bc, sol.u, tspan, p; kwargs...)
end
# Allow initial guess function for the initial guess
function TwoPointBVProblem(f::AbstractODEFunction,
bc,
initialGuess,
tspan::AbstractVector,
p = NullParameters();
kwargs...)
function TwoPointBVProblem(f::AbstractODEFunction, bc, initialGuess, tspan::AbstractVector,
p = NullParameters(); kwargs...)
u0 = [initialGuess(i) for i in tspan]
TwoPointBVProblem(f, bc, u0, (tspan[1], tspan[end]), p)
return TwoPointBVProblem(f, bc, u0, (tspan[1], tspan[end]), p; kwargs...)
end
104 changes: 45 additions & 59 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2124,8 +2124,7 @@ TruncatedStacktraces.@truncate_stacktrace OptimizationFunction 1 2
"""
$(TYPEDEF)
"""
abstract type AbstractBVPFunction{iip} <:
AbstractDiffEqFunction{iip} end
abstract type AbstractBVPFunction{iip, twopoint} <: AbstractDiffEqFunction{iip} end

@doc doc"""
BVPFunction{iip,F,BF,TMM,Ta,Tt,TJ,BCTJ,JVP,VJP,JP,BCJP,SP,TW,TWt,TPJ,S,S2,S3,O,TCV,BCTCV} <: AbstractBVPFunction{iip,specialize}
Expand Down Expand Up @@ -2230,11 +2229,9 @@ For more details on this argument, see the ODEFunction documentation.
The fields of the BVPFunction type directly match the names of the inputs.
"""
struct BVPFunction{iip, specialize, F, BF, TMM, Ta, Tt, TJ, BCTJ, JVP, VJP, JP,
BCJP, SP, TW, TWt,
TPJ,
S, S2, S3, O, TCV, BCTCV,
SYS} <: AbstractBVPFunction{iip}
struct BVPFunction{iip, specialize, twopoint, F, BF, TMM, Ta, Tt, TJ, BCTJ, JVP, VJP,
JP, BCJP, BCRP, SP, TW, TWt, TPJ, S, S2, S3, O, TCV, BCTCV,
SYS} <: AbstractBVPFunction{iip, twopoint}
f::F
bc::BF
mass_matrix::TMM
Expand All @@ -2246,6 +2243,7 @@ struct BVPFunction{iip, specialize, F, BF, TMM, Ta, Tt, TJ, BCTJ, JVP, VJP, JP,
vjp::VJP
jac_prototype::JP
bcjac_prototype::BCJP
bcresid_prototype::BCRP
sparsity::SP
Wfact::TW
Wfact_t::TWt
Expand Down Expand Up @@ -3648,9 +3646,8 @@ function NonlinearFunction{iip, specialize}(f;
nothing,
sys = __has_sys(f) ? f.sys : nothing,
resid_prototype = __has_resid_prototype(f) ? f.resid_prototype : nothing) where {
iip,
specialize,
}
iip, specialize}

if mass_matrix === I && typeof(f) <: Tuple
mass_matrix = ((I for i in 1:length(f))...,)
end
Expand Down Expand Up @@ -3814,35 +3811,28 @@ function OptimizationFunction{iip}(f, adtype::AbstractADType = NoAD();
cons_expr, sys)
end

function BVPFunction{iip, specialize}(f, bc;
mass_matrix = __has_mass_matrix(f) ? f.mass_matrix :
I,
function BVPFunction{iip, specialize, twopoint}(f, bc;
mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : I,
analytic = __has_analytic(f) ? f.analytic : nothing,
tgrad = __has_tgrad(f) ? f.tgrad : nothing,
jac = __has_jac(f) ? f.jac : nothing,
bcjac = __has_jac(bc) ? bc.jac : nothing,
jvp = __has_jvp(f) ? f.jvp : nothing,
vjp = __has_vjp(f) ? f.vjp : nothing,
jac_prototype = __has_jac_prototype(f) ?
f.jac_prototype :
nothing,
bcjac_prototype = __has_jac_prototype(bc) ?
bc.jac_prototype :
nothing,
sparsity = __has_sparsity(f) ? f.sparsity :
jac_prototype,
jac_prototype = __has_jac_prototype(f) ? f.jac_prototype : nothing,
bcjac_prototype = __has_jac_prototype(bc) ? bc.jac_prototype : nothing,
bcresid_prototype = nothing,
sparsity = __has_sparsity(f) ? f.sparsity : jac_prototype,
Wfact = __has_Wfact(f) ? f.Wfact : nothing,
Wfact_t = __has_Wfact_t(f) ? f.Wfact_t : nothing,
paramjac = __has_paramjac(f) ? f.paramjac : nothing,
syms = __has_syms(f) ? f.syms : nothing,
indepsym = __has_indepsym(f) ? f.indepsym : nothing,
paramsyms = __has_paramsyms(f) ? f.paramsyms :
nothing,
observed = __has_observed(f) ? f.observed :
DEFAULT_OBSERVED,
paramsyms = __has_paramsyms(f) ? f.paramsyms : nothing,
observed = __has_observed(f) ? f.observed : DEFAULT_OBSERVED,
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
bccolorvec = __has_colorvec(bc) ? bc.colorvec : nothing,
sys = __has_sys(f) ? f.sys : nothing) where {iip, specialize}
sys = __has_sys(f) ? f.sys : nothing) where {iip, specialize, twopoint}
if mass_matrix === I && typeof(f) <: Tuple
mass_matrix = ((I for i in 1:length(f))...,)
end
Expand Down Expand Up @@ -3882,7 +3872,7 @@ function BVPFunction{iip, specialize}(f, bc;
_bccolorvec = bccolorvec
end

bciip = isinplace(bc, 4, "bc", iip)
bciip = !twopoint ? isinplace(bc, 4, "bc", iip) : isinplace(bc, 3, "bc", iip)
jaciip = jac !== nothing ? isinplace(jac, 4, "jac", iip) : iip
bcjaciip = bcjac !== nothing ? isinplace(bcjac, 4, "bcjac", bciip) : bciip
tgradiip = tgrad !== nothing ? isinplace(tgrad, 4, "tgrad", iip) : iip
Expand All @@ -3892,66 +3882,62 @@ function BVPFunction{iip, specialize}(f, bc;
Wfact_tiip = Wfact_t !== nothing ? isinplace(Wfact_t, 5, "Wfact_t", iip) : iip
paramjaciip = paramjac !== nothing ? isinplace(paramjac, 4, "paramjac", iip) : iip

nonconforming = (jaciip,
tgradiip,
jvpiip,
vjpiip,
Wfactiip,
Wfact_tiip,
nonconforming = (bciip, jaciip, tgradiip, jvpiip, vjpiip, Wfactiip, Wfact_tiip,
paramjaciip) .!= iip
bc_nonconforming = bcjaciip .!= bciip
if any(nonconforming)
nonconforming = findall(nonconforming)
functions = ["jac", "bcjac", "tgrad", "jvp", "vjp", "Wfact", "Wfact_t", "paramjac"][nonconforming]
functions = ["bc", "jac", "bcjac", "tgrad", "jvp", "vjp", "Wfact", "Wfact_t",
"paramjac"][nonconforming]
throw(NonconformingFunctionsError(functions))
end

if twopoint
if iip && (bcresid_prototype === nothing || length(bcresid_prototype) != 2)
error("bcresid_prototype must be a tuple / indexable collection of length 2 for a inplace TwoPointBVPFunction")
end
if bcresid_prototype !== nothing && length(bcresid_prototype) == 2
bcresid_prototype = ArrayPartition(bcresid_prototype[1], bcresid_prototype[2])
end
end

if any(bc_nonconforming)
bc_nonconforming = findall(bc_nonconforming)
functions = ["bcjac"][bc_nonconforming]
throw(NonconformingFunctionsError(functions))
end

if specialize === NoSpecialize
BVPFunction{iip, specialize, Any, Any, Any, Any, Any,
Any, Any, Any, Any, Any, Any, Any, Any, Any,
BVPFunction{iip, specialize, twopoint, Any, Any, Any, Any, Any,
Any, Any, Any, Any, Any, Any, Any, Any, Any, Any,
Any, typeof(syms), typeof(indepsym), typeof(paramsyms),
Any, typeof(_colorvec), typeof(_bccolorvec), Any}(f, bc, mass_matrix,
analytic,
tgrad,
jac, bcjac, jvp, vjp,
jac_prototype,
bcjac_prototype,
sparsity, Wfact,
Wfact_t,
paramjac, syms,
indepsym, paramsyms,
observed,
analytic, tgrad, jac, bcjac, jvp, vjp, jac_prototype,
bcjac_prototype, bcresid_prototype,
sparsity, Wfact, Wfact_t, paramjac, syms, indepsym, paramsyms, observed,
_colorvec, _bccolorvec, sys)
else
BVPFunction{iip, specialize, typeof(f), typeof(bc), typeof(mass_matrix),
typeof(analytic),
typeof(tgrad),
typeof(jac), typeof(bcjac), typeof(jvp), typeof(vjp), typeof(jac_prototype),
typeof(bcjac_prototype),
typeof(sparsity), typeof(Wfact), typeof(Wfact_t),
typeof(paramjac), typeof(syms), typeof(indepsym), typeof(paramsyms),
typeof(observed),
BVPFunction{iip, specialize, twopoint, typeof(f), typeof(bc), typeof(mass_matrix),
typeof(analytic), typeof(tgrad), typeof(jac), typeof(bcjac), typeof(jvp),
typeof(vjp), typeof(jac_prototype),
typeof(bcjac_prototype), typeof(bcresid_prototype), typeof(sparsity),
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(syms),
typeof(indepsym), typeof(paramsyms), typeof(observed),
typeof(_colorvec), typeof(_bccolorvec), typeof(sys)}(f, bc, mass_matrix, analytic,
tgrad, jac, bcjac, jvp, vjp,
jac_prototype, bcjac_prototype, sparsity,
jac_prototype, bcjac_prototype, bcresid_prototype, sparsity,
Wfact, Wfact_t, paramjac,
syms, indepsym, paramsyms, observed,
_colorvec, _bccolorvec, sys)
end
end

function BVPFunction{iip}(f, bc; kwargs...) where {iip}
BVPFunction{iip, FullSpecialize}(f, bc; kwargs...)
function BVPFunction{iip}(f, bc; twopoint::Bool=false, kwargs...) where {iip}
BVPFunction{iip, FullSpecialize, twopoint}(f, bc; kwargs...)
end
BVPFunction{iip}(f::BVPFunction, bc; kwargs...) where {iip} = f
function BVPFunction(f, bc; kwargs...)
BVPFunction{isinplace(f, 4), FullSpecialize}(f, bc; kwargs...)
function BVPFunction(f, bc; twopoint::Bool=false, kwargs...)
BVPFunction{isinplace(f, 4), FullSpecialize, twopoint}(f, bc; kwargs...)
end
BVPFunction(f::BVPFunction; kwargs...) = f

Expand Down
2 changes: 1 addition & 1 deletion test/downstream/ensemble_bvp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@ tspan = (0.0, pi / 2)
p = [rand()]
bvp = BVProblem(ode!, bc!, initial_guess, tspan, p)
ensemble_prob = EnsembleProblem(bvp, prob_func = prob_func)
sim = solve(ensemble_prob, GeneralMIRK4(), trajectories = 10, dt = 0.1)
sim = solve(ensemble_prob, MIRK4(), trajectories = 10, dt = 0.1)
Loading

0 comments on commit 434e752

Please sign in to comment.