Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add nlprob to ODEFunction #800

Merged
merged 7 commits into from
Oct 30, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 41 additions & 27 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,10 @@ the usage of `f`. These include:
based on the sparsity pattern. Defaults to `nothing`, which means a color vector will be
internally computed on demand when required. The cost of this operation is highly dependent
on the sparsity pattern.
- `nlprob`: a `NonlinearProblem` that solves `f(u, t, p) = u_tmp`
where the nonlinear parameters are the tuple `(t, u_tmp, p)`.
This will be used as the nonlinear problem inside an implicit solver by specifying `u, u_tmp` and `t`
such that solving this function produces a solution to the implicit step of your solver.

## iip: In-Place vs Out-Of-Place

Expand Down Expand Up @@ -401,8 +405,8 @@ automatically symbolically generating the Jacobian and more from the
numerically-defined functions.
"""
struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ,
O, TCV,
SYS, IProb, UIProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip}
O, TCV, SYS,
IProb, UIProb, IProbMap, IProbPmap, NLP} <: AbstractODEFunction{iip}
f::F
mass_matrix::TMM
analytic::Ta
Expand All @@ -423,6 +427,7 @@ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW
update_initializeprob!::UIProb
initializeprobmap::IProbMap
initializeprobpmap::IProbPmap
nlprob::NLP
end

@doc doc"""
Expand Down Expand Up @@ -525,8 +530,8 @@ information on generating the SplitFunction from this symbolic engine.
"""
struct SplitFunction{
iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, WP, SP, TW, TWt,
TPJ, O,
TCV, SYS, IProb, UIProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip}
TPJ, O, TCV, SYS,
IProb, UIProb, IProbMap, IProbPmap, NLP} <: AbstractODEFunction{iip}
f1::F1
f2::F2
mass_matrix::TMM
Expand All @@ -549,6 +554,7 @@ struct SplitFunction{
update_initializeprob!::UIProb
initializeprobmap::IProbMap
initializeprobpmap::IProbPmap
nlprob::NLP
end

@doc doc"""
Expand Down Expand Up @@ -2426,7 +2432,8 @@ function ODEFunction{iip, specialize}(f;
update_initializeprob! = __has_update_initializeprob!(f) ?
f.update_initializeprob! : nothing,
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing,
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing,
nlprob = __has_nlprob(f) ? f.nlprob : nothing,
) where {iip,
specialize
}
Expand Down Expand Up @@ -2484,11 +2491,11 @@ function ODEFunction{iip, specialize}(f;
typeof(sparsity), Any, Any, typeof(W_prototype), Any,
Any,
typeof(_colorvec),
typeof(sys), Any, Any, Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
typeof(sys), Any, Any, Any, Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap,
initializeprobpmap)
initializeprobpmap, nlprob)
elseif specialize === false
ODEFunction{iip, FunctionWrapperSpecialize,
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
Expand All @@ -2497,13 +2504,16 @@ function ODEFunction{iip, specialize}(f;
typeof(paramjac),
typeof(observed),
typeof(_colorvec),
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!),
typeof(initializeprobmap), typeof(initializeprobpmap)}(_f, mass_matrix,
typeof(sys), typeof(initializeprob),
typeof(update_initializeprob!),
typeof(initializeprobmap),
typeof(initializeprobpmap),
typeof(nlprob)}(_f, mass_matrix,
analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap,
initializeprobpmap)
observed, _colorvec, sys, initializeprob, update_initializeprob!,
initializeprobmap, initializeprobpmap, nlprob)
else
ODEFunction{iip, specialize,
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
Expand All @@ -2514,11 +2524,12 @@ function ODEFunction{iip, specialize}(f;
typeof(_colorvec),
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!),
typeof(initializeprobmap),
typeof(initializeprobpmap)}(_f, mass_matrix, analytic, tgrad, jac,
typeof(initializeprobpmap),
typeof(nlprob)}(_f, mass_matrix, analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap,
initializeprobpmap)
initializeprobpmap, nlprob)
end
end

Expand All @@ -2535,13 +2546,13 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
Any, Any, Any, Any, typeof(f.jac_prototype),
typeof(f.sparsity), Any, Any, Any,
Any, typeof(f.colorvec),
typeof(f.sys), Any, Any, Any, Any}(
typeof(f.sys), Any, Any, Any, Any, Any}(
newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
f.Wfact_t, f.W_prototype, f.paramjac,
f.observed, f.colorvec, f.sys, f.initializeprob,
f.update_initializeprob!, f.initializeprobmap,
f.initializeprobpmap)
f.initializeprobpmap, f.nlprob)
else
ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix),
typeof(f.analytic), typeof(f.tgrad),
Expand All @@ -2551,11 +2562,12 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
typeof(f.observed), typeof(f.colorvec),
typeof(f.sys), typeof(f.initializeprob), typeof(f.update_initializeprob!),
typeof(f.initializeprobmap),
typeof(f.initializeprobpmap)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
typeof(f.initializeprobpmap),
typeof(f.nlprob)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
f.Wfact_t, f.W_prototype, f.paramjac,
f.observed, f.colorvec, f.sys, f.initializeprob, f.update_initializeprob!,
f.initializeprobmap, f.initializeprobpmap)
f.initializeprobmap, f.initializeprobpmap, f.nlprob)
end
end

Expand Down Expand Up @@ -2658,7 +2670,7 @@ end
@add_kwonly function SplitFunction(f1, f2, mass_matrix, cache, analytic, tgrad, jac, jvp,
vjp, jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac,
observed, colorvec, sys, initializeprob, update_initializeprob!,
initializeprobmap, initializeprobpmap)
initializeprobmap, initializeprobpmap, nlprob)
f1 = ODEFunction(f1)
f2 = ODEFunction(f2)

Expand All @@ -2673,11 +2685,11 @@ end
typeof(vjp), typeof(jac_prototype), typeof(W_prototype), typeof(sparsity),
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(colorvec),
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!), typeof(initializeprobmap),
typeof(initializeprobpmap)}(
typeof(initializeprobpmap), typeof(nlprob)}(
f1, f2, mass_matrix,
cache, analytic, tgrad, jac, jvp, vjp,
jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap)
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap, nlprob)
end
function SplitFunction{iip, specialize}(f1, f2;
mass_matrix = __has_mass_matrix(f1) ?
Expand Down Expand Up @@ -2713,7 +2725,8 @@ function SplitFunction{iip, specialize}(f1, f2;
update_initializeprob! = __has_update_initializeprob!(f1) ?
f1.update_initializeprob! : nothing,
initializeprobmap = __has_initializeprobmap(f1) ? f1.initializeprobmap : nothing,
initializeprobpmap = __has_initializeprobpmap(f1) ? f1.initializeprobpmap : nothing
initializeprobpmap = __has_initializeprobpmap(f1) ? f1.initializeprobpmap : nothing,
nlprob = __has_nlprob(f1) ? f1.nlprob : nothing
) where {iip,
specialize
}
Expand All @@ -2724,12 +2737,12 @@ function SplitFunction{iip, specialize}(f1, f2;
if specialize === NoSpecialize
SplitFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any,
Any, Any, Any, Any, Any, Any, Any,
Any, Any, Any, Any, Any, Any}(f1, f2, mass_matrix, _func_cache,
Any, Any, Any, Any, Any, Any, Any}(f1, f2, mass_matrix, _func_cache,
analytic,
tgrad, jac, jvp, vjp, jac_prototype, W_prototype,
sparsity, Wfact, Wfact_t, paramjac,
observed, colorvec, sys, initializeprob.update_initializeprob!, initializeprobmap,
initializeprobpmap, initializeprobpmap)
initializeprobpmap, initializeprobpmap, nlprob)
else
SplitFunction{iip, specialize, typeof(f1), typeof(f2), typeof(mass_matrix),
typeof(_func_cache), typeof(analytic),
Expand All @@ -2739,11 +2752,11 @@ function SplitFunction{iip, specialize}(f1, f2;
typeof(colorvec),
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!),
typeof(initializeprobmap),
typeof(initializeprobpmap)}(f1, f2,
typeof(initializeprobpmap), typeof(nlprob)}(f1, f2,
mass_matrix, _func_cache, analytic, tgrad, jac,
jvp, vjp, jac_prototype, W_prototype,
sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap)
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap, nlprob)
end
end

Expand Down Expand Up @@ -3086,7 +3099,7 @@ SDEFunction(f::SDEFunction; kwargs...) = f

@add_kwonly function SplitSDEFunction(f1, f2, g, mass_matrix, cache, analytic, tgrad, jac,
jvp, vjp,
jac_prototype, W_prototype, Wfact, Wfact_t, paramjac, observed,
jac_prototype, Wfact, Wfact_t, paramjac, observed,
colorvec, sys)
f1 = f1 isa AbstractSciMLOperator ? f1 : SDEFunction(f1)
f2 = SDEFunction(f2)
Expand All @@ -3097,7 +3110,7 @@ SDEFunction(f::SDEFunction; kwargs...) = f
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed),
typeof(colorvec),
typeof(sys)}(f1, f2, mass_matrix, cache, analytic, tgrad, jac,
jac_prototype, W_prototype, Wfact, Wfact_t, paramjac, observed, colorvec, sys)
jac_prototype, Wfact, Wfact_t, paramjac, observed, colorvec, sys)
end

function SplitSDEFunction{iip, specialize}(f1, f2, g;
Expand Down Expand Up @@ -4376,6 +4389,7 @@ __has_initializeprob(f) = isdefined(f, :initializeprob)
__has_update_initializeprob!(f) = isdefined(f, :update_initializeprob!)
__has_initializeprobmap(f) = isdefined(f, :initializeprobmap)
__has_initializeprobpmap(f) = isdefined(f, :initializeprobpmap)
__has_nlprob(f) = isdefined(f, :nlprob)

# compatibility
has_invW(f::AbstractSciMLFunction) = false
Expand Down
Loading