Skip to content

Commit

Permalink
add nlfunc to ODEFunction
Browse files Browse the repository at this point in the history
  • Loading branch information
oscardssmith committed Oct 14, 2024
1 parent 06864fd commit f4d8e36
Showing 1 changed file with 31 additions and 17 deletions.
48 changes: 31 additions & 17 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ the usage of `f`. These include:
based on the sparsity pattern. Defaults to `nothing`, which means a color vector will be
internally computed on demand when required. The cost of this operation is highly dependent
on the sparsity pattern.
- `nlfunc`: a `NonlinearFunction`
## iip: In-Place vs Out-Of-Place
Expand Down Expand Up @@ -401,8 +402,8 @@ automatically symbolically generating the Jacobian and more from the
numerically-defined functions.
"""
struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ,
O, TCV,
SYS, IProb, IProbMap} <: AbstractODEFunction{iip}
O, TCV, SYS, IProb, IProbMap,
NLF} <: AbstractODEFunction{iip}
f::F
mass_matrix::TMM
analytic::Ta
Expand All @@ -421,6 +422,7 @@ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW
sys::SYS
initializeprob::IProb
initializeprobmap::IProbMap
nlfunc::NLF
end

@doc doc"""
Expand Down Expand Up @@ -517,8 +519,8 @@ information on generating the SplitFunction from this symbolic engine.
"""
struct SplitFunction{
iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt,
TPJ, O,
TCV, SYS, IProb, IProbMap} <: AbstractODEFunction{iip}
TPJ, O, TCV, SYS, IProb, IProbMap,
NLF} <: AbstractODEFunction{iip}
f1::F1
f2::F2
mass_matrix::TMM
Expand All @@ -538,6 +540,7 @@ struct SplitFunction{
sys::SYS
initializeprob::IProb
initializeprobmap::IProbMap
nlfunc::NLF
end

@doc doc"""
Expand Down Expand Up @@ -2415,7 +2418,8 @@ function ODEFunction{iip, specialize}(f;
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
sys = __has_sys(f) ? f.sys : nothing,
initializeprob = __has_initializeprob(f) ? f.initializeprob : nothing,
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing,
nlfunc = __has_nlfunc(f) ? f.nlfunc : nothing,
) where {iip,
specialize
}
Expand Down Expand Up @@ -2471,12 +2475,13 @@ function ODEFunction{iip, specialize}(f;
Any, Any, Any, Any,
Any, Any, Any, typeof(jac_prototype),
typeof(sparsity), Any, Any, typeof(W_prototype), Any,
Any,
typeof(_colorvec),
typeof(sys), Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
Any,typeof(_colorvec),
typeof(sys), Any, Any,
Any}(_f, mass_matrix, analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initializeprob, initializeprobmap)
observed, _colorvec, sys, initializeprob, initializeprobmap,
nlfunc)
elseif specialize === false
ODEFunction{iip, FunctionWrapperSpecialize,
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
Expand All @@ -2486,10 +2491,12 @@ function ODEFunction{iip, specialize}(f;
typeof(observed),
typeof(_colorvec),
typeof(sys), typeof(initializeprob),
typeof(initializeprobmap)}(_f, mass_matrix, analytic, tgrad, jac,
typeof(initializeprobmap),
typeof(nlfunc)}(_f, mass_matrix, analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initializeprob, initializeprobmap)
observed, _colorvec, sys, initializeprob, initializeprobmap,
nlfun)
else
ODEFunction{iip, specialize,
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
Expand All @@ -2499,10 +2506,12 @@ function ODEFunction{iip, specialize}(f;
typeof(observed),
typeof(_colorvec),
typeof(sys), typeof(initializeprob),
typeof(initializeprobmap)}(_f, mass_matrix, analytic, tgrad, jac,
typeof(initializeprobmap),
typeof(nlfunc))}(_f, mass_matrix, analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initializeprob, initializeprobmap)
observed, _colorvec, sys, initializeprob, initializeprobmap,
nlfunc)
end
end

Expand All @@ -2519,10 +2528,12 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
Any, Any, Any, Any, typeof(f.jac_prototype),
typeof(f.sparsity), Any, Any, Any,
Any, typeof(f.colorvec),
typeof(f.sys), Any, Any}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
typeof(f.sys), Any, Any,
Any}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
f.Wfact_t, f.W_prototype, f.paramjac,
f.observed, f.colorvec, f.sys, f.initializeprob, f.initializeprobmap)
f.observed, f.colorvec, f.sys, f.initializeprob, f.initializeprobmap,
f.nlfunc)
else
ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix),
typeof(f.analytic), typeof(f.tgrad),
Expand All @@ -2531,11 +2542,13 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
typeof(f.paramjac),
typeof(f.observed), typeof(f.colorvec),
typeof(f.sys), typeof(f.initializeprob),
typeof(f.initializeprobmap)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
typeof(f.initializeprobmap),
typeof(f.nlfunc)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
f.Wfact_t, f.W_prototype, f.paramjac,
f.observed, f.colorvec, f.sys, f.initializeprob,
f.initializeprobmap)
f.initializeprobmap,
f.nlfunc)
end
end

Expand Down Expand Up @@ -4336,6 +4349,7 @@ __has_analytic_full(f) = isdefined(f, :analytic_full)
__has_resid_prototype(f) = isdefined(f, :resid_prototype)
__has_initializeprob(f) = isdefined(f, :initializeprob)
__has_initializeprobmap(f) = isdefined(f, :initializeprobmap)
__has_nlfunc(f) = isdefined(f, :nl_func)

# compatibility
has_invW(f::AbstractSciMLFunction) = false
Expand Down

0 comments on commit f4d8e36

Please sign in to comment.