From 7aa84b2fde761a3a25e38032d55119ad287f3f66 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Thu, 10 Oct 2024 16:34:19 -0400 Subject: [PATCH 1/2] add W_prototype to SplitFunction --- src/scimlfunctions.jl | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 0b41d59d9..c60cae3ff 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -518,7 +518,7 @@ numerically-defined functions. See `ModelingToolkit.SplitODEProblem` for information on generating the SplitFunction from this symbolic engine. """ struct SplitFunction{ - iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, + iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, WP, SP, TW, TWt, TPJ, O, TCV, SYS, IProb, UIProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip} f1::F1 @@ -531,6 +531,7 @@ struct SplitFunction{ jvp::JVP vjp::VJP jac_prototype::JP + W_prototype::WP sparsity::SP Wfact::TW Wfact_t::TWt @@ -1813,9 +1814,9 @@ OptimizationFunction{iip}(f, adtype::AbstractADType = NoAD(); ## Positional Arguments -- `f(u,p)`: the function to optimize. `u` are the optimization variables and `p` are fixed parameters or data used in the objective, +- `f(u,p)`: the function to optimize. `u` are the optimization variables and `p` are fixed parameters or data used in the objective, even if no such parameters are used in the objective it should be an argument in the function. For minibatching `p` can be used to pass in -a minibatch, take a look at the tutorial [here](https://docs.sciml.ai/Optimization/stable/tutorials/minibatch/) to see how to do it. +a minibatch, take a look at the tutorial [here](https://docs.sciml.ai/Optimization/stable/tutorials/minibatch/) to see how to do it. This should return a scalar, the loss value, as the return output. - `adtype`: see the Defining Optimization Functions via AD section below. @@ -2649,7 +2650,7 @@ function NonlinearFunction{iip}(f::ODEFunction) where {iip} end @add_kwonly function SplitFunction(f1, f2, mass_matrix, cache, analytic, tgrad, jac, jvp, - vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, + vjp, jac_prototype, W__prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap) f1 = ODEFunction(f1) @@ -2663,13 +2664,13 @@ end SplitFunction{isinplace(f2), FullSpecialize, typeof(f1), typeof(f2), typeof(mass_matrix), typeof(cache), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), - typeof(vjp), typeof(jac_prototype), typeof(sparsity), + typeof(vjp), typeof(jac_prototype), typeof(W_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(colorvec), typeof(sys), typeof(initializeprob), typeof(update_initializeprob!), typeof(initializeprobmap), typeof(initializeprobpmap)}( f1, f2, mass_matrix, cache, analytic, tgrad, jac, jvp, vjp, - jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys, + jac_prototype, W__prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap) end function SplitFunction{iip, specialize}(f1, f2; @@ -2685,6 +2686,9 @@ function SplitFunction{iip, specialize}(f1, f2; jac_prototype = __has_jac_prototype(f1) ? f1.jac_prototype : nothing, + W_prototype = __has_W_prototype(f1) ? + f1.W_prototype : + nothing, sparsity = __has_sparsity(f1) ? f1.sparsity : jac_prototype, Wfact = __has_Wfact(f1) ? f1.Wfact : nothing, @@ -2713,10 +2717,10 @@ function SplitFunction{iip, specialize}(f1, f2; if specialize === NoSpecialize SplitFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any, - Any, Any, Any, Any, Any, Any, + Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any}(f1, f2, mass_matrix, _func_cache, analytic, - tgrad, jac, jvp, vjp, jac_prototype, + tgrad, jac, jvp, vjp, jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys, initializeprob.update_initializeprob!, initializeprobmap, initializeprobpmap, initializeprobpmap) @@ -2724,14 +2728,14 @@ function SplitFunction{iip, specialize}(f1, f2; SplitFunction{iip, specialize, typeof(f1), typeof(f2), typeof(mass_matrix), typeof(_func_cache), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp), - typeof(jac_prototype), typeof(sparsity), + typeof(jac_prototype), typeof(W_prototype), typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(colorvec), typeof(sys), typeof(initializeprob), typeof(update_initializeprob!), typeof(initializeprobmap), typeof(initializeprobpmap)}(f1, f2, mass_matrix, _func_cache, analytic, tgrad, jac, - jvp, vjp, jac_prototype, + jvp, vjp, jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap) end From bdce135a115f382a0563ddc426a50b0673684e04 Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Fri, 11 Oct 2024 10:47:27 -0400 Subject: [PATCH 2/2] typo --- src/scimlfunctions.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index c60cae3ff..e783c5277 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -2650,7 +2650,7 @@ function NonlinearFunction{iip}(f::ODEFunction) where {iip} end @add_kwonly function SplitFunction(f1, f2, mass_matrix, cache, analytic, tgrad, jac, jvp, - vjp, jac_prototype, W__prototype, sparsity, Wfact, Wfact_t, paramjac, + vjp, jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap) f1 = ODEFunction(f1)