diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index ee1dc0d79..c8c8d0b66 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -724,6 +724,7 @@ include("problems/problem_interface.jl") include("problems/optimization_problems.jl") include("clock.jl") +include("solutions/save_idxs.jl") include("solutions/basic_solutions.jl") include("solutions/nonlinear_solutions.jl") include("solutions/ode_solutions.jl") diff --git a/src/integrator_interface.jl b/src/integrator_interface.jl index a3137ea7d..9f042681f 100644 --- a/src/integrator_interface.jl +++ b/src/integrator_interface.jl @@ -331,19 +331,22 @@ Otherwise the integrator is allowed to skip recalculating the interpolation. # Arguments -- `continuous_modification`: determines whether the modification is due to a continuous change (continuous callback) - or a discrete callback. For a continuous change, this can include a change to time which requires a re-evaluation - of the interpolations. -- `callback_initializealg`: the initialization algorithm provided by the callback. For DAEs, this is the choice for the - initialization that is done post callback. The default value of `nothing` means that the initialization choice - used for the DAE should be performed post-callback. + - `continuous_modification`: determines whether the modification is due to a continuous change (continuous callback) + or a discrete callback. For a continuous change, this can include a change to time which requires a re-evaluation + of the interpolations. + - `callback_initializealg`: the initialization algorithm provided by the callback. For DAEs, this is the choice for the + initialization that is done post callback. The default value of `nothing` means that the initialization choice + used for the DAE should be performed post-callback. """ function reeval_internals_due_to_modification!( integrator::DEIntegrator, continuous_modification; callback_initializealg = nothing) reeval_internals_due_to_modification!(integrator::DEIntegrator) end -reeval_internals_due_to_modification!(integrator::DEIntegrator; callback_initializealg = nothing) = nothing +function reeval_internals_due_to_modification!( + integrator::DEIntegrator; callback_initializealg = nothing) + nothing +end """ set_t!(integrator::DEIntegrator, t) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 22981e8dd..1d9ece710 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -2693,8 +2693,8 @@ function SplitFunction{iip, specialize}(f1, f2; f1.jac_prototype : nothing, W_prototype = __has_W_prototype(f1) ? - f1.W_prototype : - nothing, + f1.W_prototype : + nothing, sparsity = __has_sparsity(f1) ? f1.sparsity : jac_prototype, Wfact = __has_Wfact(f1) ? f1.Wfact : nothing, diff --git a/src/solutions/dae_solutions.jl b/src/solutions/dae_solutions.jl index a3fb1a8d1..cab8686fb 100644 --- a/src/solutions/dae_solutions.jl +++ b/src/solutions/dae_solutions.jl @@ -27,7 +27,7 @@ https://docs.sciml.ai/DiffEqDocs/stable/basics/solution/ exited due to an error. For more details, see [the return code documentation](https://docs.sciml.ai/SciMLBase/stable/interfaces/Solutions/#retcodes). """ -struct DAESolution{T, N, uType, duType, uType2, DType, tType, P, A, ID, S, rateType} <: +struct DAESolution{T, N, uType, duType, uType2, DType, tType, P, A, ID, S, rateType, V} <: AbstractDAESolution{T, N, uType} u::uType du::duType @@ -42,6 +42,31 @@ struct DAESolution{T, N, uType, duType, uType2, DType, tType, P, A, ID, S, rateT tslocation::Int stats::S retcode::ReturnCode.T + saved_subsystem::V +end + +function DAESolution{T, N}(u, du, u_analytic, errors, t, k, prob, alg, interp, dense, + tslocation, stats, retcode, saved_subsystem) where {T, N} + return DAESolution{T, N, typeof(u), typeof(du), typeof(u_analytic), typeof(errors), + typeof(t), typeof(prob), typeof(alg), typeof(interp), typeof(stats), typeof(k), + typeof(saved_subsystem)}( + u, du, u_analytic, errors, t, k, prob, alg, interp, dense, tslocation, stats, + retcode, saved_subsystem + ) +end + +function ConstructionBase.constructorof(::Type{O}) where {T, N, O <: DAESolution{T, N}} + DAESolution{T, N} +end + +function ConstructionBase.setproperties(sol::DAESolution, patch::NamedTuple) + u = get(patch, :u, sol.u) + N = u === nothing ? 2 : ndims(eltype(u)) + 1 + T = eltype(eltype(u)) + patch = merge(getproperties(sol), patch) + return DAESolution{T, N}(patch.u, patch.du, patch.u_analytic, patch.errors, patch.t, + patch.k, patch.prob, patch.alg, patch.interp, patch.dense, patch.tslocation, + patch.stats, patch.retcode, patch.saved_subsystem) end Base.@propagate_inbounds function Base.getproperty(x::AbstractDAESolution, s::Symbol) @@ -65,13 +90,14 @@ function build_solution(prob::AbstractDAEProblem, alg, t, u, du = nothing; retcode = ReturnCode.Default, destats = missing, stats = nothing, + saved_subsystem = nothing, kwargs...) T = eltype(eltype(u)) if prob.u0 === nothing N = 2 else - N = length((size(prob.u0)..., length(u))) + N = ndims(eltype(u)) + 1 end if !ismissing(destats) @@ -88,7 +114,8 @@ function build_solution(prob::AbstractDAEProblem, alg, t, u, du = nothing; errors = Dict{Symbol, real(eltype(prob.u0))}() sol = DAESolution{T, N, typeof(u), typeof(du), typeof(u_analytic), typeof(errors), - typeof(t), typeof(prob), typeof(alg), typeof(interp), typeof(stats), typeof(k)}( + typeof(t), typeof(prob), typeof(alg), typeof(interp), typeof(stats), typeof(k), + typeof(saved_subsystem)}( u, du, u_analytic, @@ -101,7 +128,8 @@ function build_solution(prob::AbstractDAEProblem, alg, t, u, du = nothing; dense, 0, stats, - retcode) + retcode, + saved_subsystem) if calculate_error calculate_solution_errors!(sol; timeseries_errors = timeseries_errors, @@ -110,7 +138,8 @@ function build_solution(prob::AbstractDAEProblem, alg, t, u, du = nothing; sol else DAESolution{T, N, typeof(u), typeof(du), Nothing, Nothing, typeof(t), - typeof(prob), typeof(alg), typeof(interp), typeof(stats), typeof(k)}( + typeof(prob), typeof(alg), typeof(interp), typeof(stats), typeof(k), + typeof(saved_subsystem)}( u, du, nothing, nothing, t, k, @@ -118,7 +147,8 @@ function build_solution(prob::AbstractDAEProblem, alg, t, u, du = nothing; interp, dense, 0, stats, - retcode) + retcode, + saved_subsystem) end end @@ -161,76 +191,23 @@ function calculate_solution_errors!(sol::AbstractDAESolution; end function build_solution(sol::AbstractDAESolution{T, N}, u_analytic, errors) where {T, N} - DAESolution{T, N, typeof(sol.u), typeof(sol.du), typeof(u_analytic), typeof(errors), - typeof(sol.t), typeof(sol.prob), typeof(sol.alg), typeof(sol.interp), - typeof(sol.stats), typeof(sol.k)}(sol.u, - sol.du, - u_analytic, - errors, - sol.t, - sol.k, - sol.prob, - sol.alg, - sol.interp, - sol.dense, - sol.tslocation, - sol.stats, - sol.retcode) + @reset sol.u_analytic = u_analytic + return @set sol.errors = errors end function solution_new_retcode(sol::AbstractDAESolution{T, N}, retcode) where {T, N} - DAESolution{T, N, typeof(sol.u), typeof(sol.du), typeof(sol.u_analytic), - typeof(sol.errors), typeof(sol.t), typeof(sol.prob), typeof(sol.alg), - typeof(sol.interp), typeof(sol.stats), typeof(sol.k)}(sol.u, - sol.du, - sol.u_analytic, - sol.errors, - sol.t, - sol.k, - sol.prob, - sol.alg, - sol.interp, - sol.dense, - sol.tslocation, - sol.stats, - retcode) + return @set sol.retcode = retcode end function solution_new_tslocation(sol::AbstractDAESolution{T, N}, tslocation) where {T, N} - DAESolution{T, N, typeof(sol.u), typeof(sol.du), typeof(sol.u_analytic), - typeof(sol.errors), typeof(sol.t), typeof(sol.prob), typeof(sol.alg), - typeof(sol.interp), typeof(sol.stats), typeof(k)}(sol.u, - sol.du, - sol.u_analytic, - sol.errors, - sol.t, - sol.k, - sol.prob, - sol.alg, - sol.interp, - sol.dense, - tslocation, - sol.stats, - sol.retcode) + return @set sol.tslocation = tslocation end function solution_slice(sol::AbstractDAESolution{T, N}, I) where {T, N} - DAESolution{T, N, typeof(sol.u), typeof(sol.du), typeof(sol.u_analytic), - typeof(sol.errors), typeof(sol.t), typeof(sol.prob), typeof(sol.alg), - typeof(sol.interp), typeof(sol.stats), typeof(sol.k)}(sol.u[I], - sol.du[I], - sol.u_analytic === - nothing ? - nothing : - sol.u_analytic[I], - sol.errors, - sol.t[I], - sol.k[I], - sol.prob, - sol.alg, - sol.interp, - false, - sol.tslocation, - sol.stats, - sol.retcode) + @reset sol.u = sol.u[I] + @reset sol.du = sol.du[I] + @reset sol.u_analytic = sol.u_analytic === nothing ? nothing : sol.u_analytic[I] + @reset sol.t = sol.t[I] + @reset sol.k = sol.dense ? sol.k[I] : sol.k + return @set sol.dense = false end diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index bb2a9fabc..f947d5895 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -104,9 +104,12 @@ https://docs.sciml.ai/DiffEqDocs/stable/basics/solution/ successfully, whether it terminated early due to a user-defined callback, or whether it exited due to an error. For more details, see [the return code documentation](https://docs.sciml.ai/SciMLBase/stable/interfaces/Solutions/#retcodes). +- `saved_subsystem`: a [`SavedSubsystem`](@ref) representing the subset of variables saved + in this solution, or `nothing` if all variables are saved. Here "variables" refers to + both continuous-time state variables and timeseries parameters. """ struct ODESolution{T, N, uType, uType2, DType, tType, rateType, discType, P, A, IType, S, - AC <: Union{Nothing, Vector{Int}}, R, O} <: + AC <: Union{Nothing, Vector{Int}}, R, O, V} <: AbstractODESolution{T, N, uType} u::uType u_analytic::uType2 @@ -124,6 +127,7 @@ struct ODESolution{T, N, uType, uType2, DType, tType, rateType, discType, P, A, retcode::ReturnCode.T resid::R original::O + saved_subsystem::V end function ConstructionBase.constructorof(::Type{O}) where {T, N, O <: ODESolution{T, N}} @@ -137,7 +141,7 @@ function ConstructionBase.setproperties(sol::ODESolution, patch::NamedTuple) patch = merge(getproperties(sol), patch) return ODESolution{T, N}(patch.u, patch.u_analytic, patch.errors, patch.t, patch.k, patch.discretes, patch.prob, patch.alg, patch.interp, patch.dense, patch.tslocation, patch.stats, - patch.alg_choice, patch.retcode, patch.resid, patch.original) + patch.alg_choice, patch.retcode, patch.resid, patch.original, patch.saved_subsystem) end Base.@propagate_inbounds function Base.getproperty(x::AbstractODESolution, s::Symbol) @@ -154,12 +158,12 @@ end function ODESolution{T, N}( u, u_analytic, errors, t, k, discretes, prob, alg, interp, dense, tslocation, stats, alg_choice, retcode, resid = nothing, - original = nothing) where {T, N} + original = nothing, saved_subsystem = nothing) where {T, N} return ODESolution{T, N, typeof(u), typeof(u_analytic), typeof(errors), typeof(t), typeof(k), typeof(discretes), typeof(prob), typeof(alg), typeof(interp), - typeof(stats), typeof(alg_choice), typeof(resid), - typeof(original)}(u, u_analytic, errors, t, k, discretes, prob, alg, interp, - dense, tslocation, stats, alg_choice, retcode, resid, original) + typeof(stats), typeof(alg_choice), typeof(resid), typeof(original), + typeof(saved_subsystem)}(u, u_analytic, errors, t, k, discretes, prob, alg, interp, + dense, tslocation, stats, alg_choice, retcode, resid, original, saved_subsystem) end error_if_observed_derivative(_, _, ::Type{Val{0}}) = nothing @@ -409,15 +413,25 @@ const PeriodicDiffEqArray = DiffEqArray{T, N, A, B} where {T, N, A, B <: Abstrac # public API, used by MTK """ get_saveable_values(sys, ps, timeseries_idx) + +Return the values to be saved in parameter object `ps` for timeseries index `timeseries_idx`. Called by +`save_discretes!`. If this returns `nothing`, `save_discretes!` will not save anything. """ function get_saveable_values(sys, ps, timeseries_idx) return get_saveable_values(symbolic_container(sys), ps, timeseries_idx) end +""" + save_discretes!(integ::DEIntegrator, timeseries_idx) + +Save the parameter timeseries with index `timeseries_idx`. Calls `get_saveable_values` to +get the values to save. If it returns `nothing`, then the save does not happen. +""" function save_discretes!(integ::DEIntegrator, timeseries_idx) - save_discretes!(integ.sol, current_time(integ), - get_saveable_values(integ, parameter_values(integ), timeseries_idx), - timeseries_idx) + inner_sol = get_sol(integ) + vals = get_saveable_values(inner_sol, parameter_values(integ), timeseries_idx) + vals === nothing && return + save_discretes!(integ.sol, current_time(integ), vals, timeseries_idx) end save_discretes!(args...) = nothing @@ -451,6 +465,7 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem}, interp = LinearInterpolation(t, u), retcode = ReturnCode.Default, destats = missing, stats = nothing, resid = nothing, original = nothing, + saved_subsystem = nothing, kwargs...) T = eltype(eltype(u)) @@ -482,7 +497,12 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem}, ps = parameter_values(prob) if has_sys(prob.f) - discretes = create_parameter_timeseries_collection(prob.f.sys, ps, prob.tspan) + sswf = if saved_subsystem === nothing + prob.f.sys + else + SavedSubsystemWithFallback(saved_subsystem, prob.f.sys) + end + discretes = create_parameter_timeseries_collection(sswf, ps, prob.tspan) else discretes = nothing end @@ -503,7 +523,8 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem}, alg_choice, retcode, resid, - original) + original, + saved_subsystem) if calculate_error calculate_solution_errors!(sol; timeseries_errors = timeseries_errors, dense_errors = dense_errors) @@ -524,7 +545,8 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem}, alg_choice, retcode, resid, - original) + original, + saved_subsystem) end end @@ -593,7 +615,7 @@ function solution_slice(sol::ODESolution{T, N}, I) where {T, N} @reset sol.u_analytic = sol.u_analytic === nothing ? nothing : sol.u_analytic[I] @reset sol.t = sol.t[I] @reset sol.k = sol.dense ? sol.k[I] : sol.k - return @set sol.alg = false + return @set sol.dense = false end mask_discretes(::Nothing, _, _...) = nothing diff --git a/src/solutions/rode_solutions.jl b/src/solutions/rode_solutions.jl index 340cde11a..2bf7d5084 100644 --- a/src/solutions/rode_solutions.jl +++ b/src/solutions/rode_solutions.jl @@ -33,7 +33,7 @@ https://docs.sciml.ai/DiffEqDocs/stable/basics/solution/ [the return code documentation](https://docs.sciml.ai/SciMLBase/stable/interfaces/Solutions/#retcodes). """ struct RODESolution{T, N, uType, uType2, DType, tType, randType, P, A, IType, S, - AC <: Union{Nothing, Vector{Int}}} <: + AC <: Union{Nothing, Vector{Int}}, V} <: AbstractRODESolution{T, N, uType} u::uType u_analytic::uType2 @@ -49,6 +49,7 @@ struct RODESolution{T, N, uType, uType2, DType, tType, randType, P, A, IType, S, alg_choice::AC retcode::ReturnCode.T seed::UInt64 + saved_subsystem::V end function ConstructionBase.constructorof(::Type{O}) where {T, N, O <: RODESolution{T, N}} @@ -63,10 +64,10 @@ function ConstructionBase.setproperties(sol::RODESolution, patch::NamedTuple) return RODESolution{ T, N, typeof(patch.u), typeof(patch.u_analytic), typeof(patch.errors), typeof(patch.t), typeof(patch.W), typeof(patch.prob), typeof(patch.alg), typeof(patch.interp), - typeof(patch.stats), typeof(patch.alg_choice)}( + typeof(patch.stats), typeof(patch.alg_choice), typeof(patch.saved_subsystem)}( patch.u, patch.u_analytic, patch.errors, patch.t, patch.W, patch.prob, patch.alg, patch.interp, patch.dense, patch.tslocation, patch.stats, - patch.alg_choice, patch.retcode, patch.seed) + patch.alg_choice, patch.retcode, patch.seed, patch.saved_subsystem) end Base.@propagate_inbounds function Base.getproperty(x::AbstractRODESolution, s::Symbol) @@ -94,9 +95,14 @@ function build_solution(prob::Union{AbstractRODEProblem, AbstractSDDEProblem}, interp = LinearInterpolation(t, u), retcode = ReturnCode.Default, alg_choice = nothing, - seed = UInt64(0), destats = missing, stats = nothing, kwargs...) + seed = UInt64(0), destats = missing, stats = nothing, + saved_subsystem = nothing, kwargs...) T = eltype(eltype(u)) - N = length((size(prob.u0)..., length(u))) + if prob.u0 === nothing + N = 2 + else + N = ndims(eltype(u)) + 1 + end if prob.f isa Tuple f = prob.f[1] @@ -120,7 +126,7 @@ function build_solution(prob::Union{AbstractRODEProblem, AbstractSDDEProblem}, sol = RODESolution{T, N, typeof(u), typeof(u_analytic), typeof(errors), typeof(t), typeof(W), typeof(prob), typeof(alg), typeof(interp), typeof(stats), - typeof(alg_choice)}(u, + typeof(alg_choice), typeof(saved_subsystem)}(u, u_analytic, errors, t, W, @@ -132,7 +138,8 @@ function build_solution(prob::Union{AbstractRODEProblem, AbstractSDDEProblem}, stats, alg_choice, retcode, - seed) + seed, + saved_subsystem) if calculate_error calculate_solution_errors!(sol; timeseries_errors = timeseries_errors, @@ -143,10 +150,11 @@ function build_solution(prob::Union{AbstractRODEProblem, AbstractSDDEProblem}, else return RODESolution{T, N, typeof(u), Nothing, Nothing, typeof(t), typeof(W), typeof(prob), typeof(alg), typeof(interp), - typeof(stats), typeof(alg_choice)}(u, nothing, nothing, t, W, + typeof(stats), typeof(alg_choice), typeof(saved_subsystem)}( + u, nothing, nothing, t, W, prob, alg, interp, dense, 0, stats, - alg_choice, retcode, seed) + alg_choice, retcode, seed, saved_subsystem) end end @@ -197,54 +205,24 @@ function calculate_solution_errors!(sol::AbstractRODESolution; fill_uanalytic = end end -function build_solution(sol::AbstractRODESolution{T, N}, u_analytic, errors) where {T, N} - RODESolution{T, N, typeof(sol.u), typeof(u_analytic), typeof(errors), typeof(sol.t), - typeof(sol.W), typeof(sol.prob), typeof(sol.alg), typeof(sol.interp), - typeof(sol.stats), typeof(sol.alg_choice)}(sol.u, u_analytic, errors, - sol.t, sol.W, sol.prob, - sol.alg, sol.interp, - sol.dense, sol.tslocation, - sol.stats, sol.alg_choice, - sol.retcode, sol.seed) +function build_solution(sol::AbstractRODESolution, u_analytic, errors) + @reset sol.u_analytic = u_analytic + return @set sol.errors = errors end -function solution_new_retcode(sol::AbstractRODESolution{T, N}, retcode) where {T, N} - RODESolution{T, N, typeof(sol.u), typeof(sol.u_analytic), typeof(sol.errors), - typeof(sol.t), - typeof(sol.W), typeof(sol.prob), typeof(sol.alg), typeof(sol.interp), - typeof(sol.stats), typeof(sol.alg_choice)}(sol.u, sol.u_analytic, - sol.errors, sol.t, sol.W, - sol.prob, sol.alg, sol.interp, - sol.dense, sol.tslocation, - sol.stats, sol.alg_choice, - retcode, sol.seed) +function solution_new_retcode(sol::AbstractRODESolution, retcode) + return @set sol.retcode = retcode end -function solution_new_tslocation(sol::AbstractRODESolution{T, N}, tslocation) where {T, N} - RODESolution{T, N, typeof(sol.u), typeof(sol.u_analytic), typeof(sol.errors), - typeof(sol.t), - typeof(sol.W), typeof(sol.prob), typeof(sol.alg), typeof(sol.interp), - typeof(sol.stats), typeof(sol.alg_choice)}(sol.u, sol.u_analytic, - sol.errors, sol.t, sol.W, - sol.prob, sol.alg, sol.interp, - sol.dense, tslocation, - sol.stats, sol.alg_choice, - sol.retcode, sol.seed) +function solution_new_tslocation(sol::AbstractRODESolution, tslocation) + return @set sol.tslocation = tslocation end function solution_slice(sol::AbstractRODESolution{T, N}, I) where {T, N} - RODESolution{T, N, typeof(sol.u), typeof(sol.u_analytic), typeof(sol.errors), - typeof(sol.t), - typeof(sol.W), typeof(sol.prob), typeof(sol.alg), typeof(sol.interp), - typeof(sol.stats), typeof(sol.alg_choice)}(sol.u[I], - sol.u_analytic === nothing ? - nothing : sol.u_analytic, - sol.errors, sol.t[I], - sol.W, sol.prob, - sol.alg, sol.interp, - false, sol.tslocation, - sol.stats, sol.alg_choice, - sol.retcode, sol.seed) + @reset sol.u = sol.u[I] + @reset sol.u_analytic = sol.u_analytic === nothing ? nothing : sol.u_analytic[I] + @reset sol.t = sol.t[I] + return @set sol.dense = false end function sensitivity_solution(sol::AbstractRODESolution, u, t) @@ -259,22 +237,7 @@ function sensitivity_solution(sol::AbstractRODESolution, u, t) end interp = enable_interpolation_sensitivitymode(sol.interp) - - RODESolution{T, N, typeof(u), typeof(sol.u_analytic), - typeof(sol.errors), typeof(t), - typeof(nothing), typeof(sol.prob), typeof(sol.alg), - typeof(sol.interp), typeof(sol.stats), typeof(sol.alg_choice)}(u, - sol.u_analytic, - sol.errors, - t, - nothing, - sol.prob, - sol.alg, - sol.interp, - sol.dense, - sol.tslocation, - sol.stats, - sol.alg_choice, - sol.retcode, - sol.seed) + @reset sol.u = u + @reset sol.t = t isa Vector ? t : collect(t) + return @set sol.interp = interp end diff --git a/src/solutions/save_idxs.jl b/src/solutions/save_idxs.jl new file mode 100644 index 000000000..f6839a015 --- /dev/null +++ b/src/solutions/save_idxs.jl @@ -0,0 +1,322 @@ +#= +(Symbolic) save_idxs interface: + +Allows symbolically indexing solutions where a subset of the variables are saved. + +To implement this interface, the solution must store a `SavedSubsystem` if it contains a +subset of all timeseries variables. The `get_saved_subsystem` function must be implemented +to return the `SavedSubsystem` if present and `nothing` otherwise. + +The solution must forward `is_timeseries_parameter`, `timeseries_parameter_index`, +`with_updated_parameter_timeseries_values` and `get_saveable_values` to +`SavedSubsystemWithFallback(sol.saved_subsystem, symbolic_container(sol))`. + +Additionally, it must implement `state_values` to always return the full state vector, using +the `SciMLProblem`'s `u0` as a reference, and updating the saved values in it. + +See the implementation for `ODESolution` as a reference. +=# + +get_saved_subsystem(_) = nothing + +struct VectorTemplate + type::DataType + size::Int +end + +struct TupleOfArraysWrapper{T} + x::T +end + +function TupleOfArraysWrapper(vt::Vector{VectorTemplate}) + return TupleOfArraysWrapper(Tuple(map(t -> Vector{t.type}(undef, t.size), vt))) +end + +function Base.getindex(t::TupleOfArraysWrapper, i::Tuple{Int, Int}) + t.x[i[1]][i[2]] +end + +function Base.setindex!(t::TupleOfArraysWrapper, val, i::Tuple{Int, Int}) + t.x[i[1]][i[2]] = val +end + +function as_diffeq_array(vt::Vector{VectorTemplate}, t) + return DiffEqArray(typeof(TupleOfArraysWrapper(vt))[], t, (1, 1)) +end + +""" + $(TYPEDSIGNATURES) + +A representation of the subsystem of a given system which is saved in a solution. Created +by providing an index provider and the indexes of saved variables in the system. The indexes +can also be symbolic variables. All indexes must refer to state variables, or timeseries +parameters. + +The arguments to the constructor are an index provider, the parameter object and the indexes +of variables to save. + +This object is stored in the solution object and used for symbolic indexing of the subsetted +solution. + +In case the provided `saved_idxs` is `nothing` or `isempty`, or if the provided +`saved_idxs` includes all of the variables and timeseries parameters, returns `nothing`. +""" +struct SavedSubsystem{V, T, M, I, P, Q, C} + """ + `Dict` mapping indexes of saved variables in the parent system to corresponding + indexes in the saved continuous timeseries. + """ + state_map::V + """ + `Dict` mapping indexes of saved timeseries parameters in the parent system to + corresponding `ParameterTimeseriesIndex`es in the save parameter timeseries. + """ + timeseries_params_map::T + """ + `Dict` mapping `ParameterTimeseriesIndex`es to indexes of parameters in the + system. (`timeseries_parameter_index => parameter_index`) + """ + timeseries_idx_to_param_idx::M + """ + `Set` of all timeseries_idxs that are saved as-is. + """ + identity_partitions::I + """ + `Dict` mapping timeseries indexes to a vector of `VectorTemplate`s to use for storing + that subsetted timeseries partition + """ + timeseries_partition_templates::P + """ + `Dict` mapping timeseries indexes to a vector of `ParameterTimeseriesIndex` in that + partition. Only for saved partitions not in `identity_partitions`. + """ + indexes_in_partition::Q + """ + Map of timeseries indexes to the number of saved timeseries parameters in that + partition. + """ + partition_count::C +end + +function SavedSubsystem(indp, pobj, saved_idxs) + # nothing saved + if saved_idxs === nothing || isempty(saved_idxs) + return nothing + end + + # array state symbolics must be scalarized + saved_idxs = collect(Iterators.flatten(map(saved_idxs) do sym + if symbolic_type(sym) == NotSymbolic() + (sym,) + elseif sym isa AbstractArray && is_variable(indp, sym) + collect(sym) + else + (sym,) + end + end)) + + saved_state_idxs = Int[] + ts_idx_to_type_to_param_idx = Dict() + ts_idx_to_count = Dict() + num_ts_params = 0 + TParammapKeys = Union{} + TParamIdx = Union{} + timeseries_idx_to_param_idx = Dict() + for var in saved_idxs + if (idx = variable_index(indp, var)) !== nothing + push!(saved_state_idxs, idx) + elseif (idx = timeseries_parameter_index(indp, var)) !== nothing + TParammapKeys = Base.promote_typejoin(TParammapKeys, typeof(idx)) + # increment total number of ts params + num_ts_params += 1 + # get dict mapping type to idxs for this timeseries_idx + buf = get!(() -> Dict(), ts_idx_to_type_to_param_idx, idx.timeseries_idx) + # get type of parameter + pidx = parameter_index(indp, var) + timeseries_idx_to_param_idx[idx] = pidx + TParamIdx = Base.promote_typejoin(TParamIdx, typeof(pidx)) + val = parameter_values(pobj, pidx) + T = typeof(val) + # get vector of idxs for this type + buf = get!(() -> [], buf, T) + # push to it + push!(buf, idx) + # update count of variables in this partition + cnt = get(ts_idx_to_count, idx.timeseries_idx, 0) + ts_idx_to_count[idx.timeseries_idx] = cnt + 1 + else + throw(ArgumentError("Can only save variables and timeseries parameters. Got $var.")) + end + end + + # type of timeseries_idxs + Ttsidx = Union{} + for k in keys(ts_idx_to_type_to_param_idx) + Ttsidx = Base.promote_typejoin(Ttsidx, typeof(k)) + end + + # timeseries_idx to timeseries_parameter_index for all params + all_ts_params = Dict() + num_all_ts_params = 0 + for var in parameter_symbols(indp) + if (idx = timeseries_parameter_index(indp, var)) !== nothing + num_all_ts_params += 1 + buf = get!(() -> [], all_ts_params, idx.timeseries_idx) + push!(buf, idx) + end + end + + save_all_states = length(saved_state_idxs) == length(variable_symbols(indp)) + save_all_tsparams = num_ts_params == num_all_ts_params + if save_all_states && save_all_tsparams + # if we're saving everything + return nothing + end + if save_all_states + sort!(saved_state_idxs) + state_map = saved_state_idxs + else + state_map = Dict(saved_state_idxs .=> collect(eachindex(saved_state_idxs))) + end + + if save_all_tsparams + if isempty(ts_idx_to_type_to_param_idx) + identity_partitions = () + else + identity_partitions = Set{Ttsidx}(keys(ts_idx_to_type_to_param_idx)) + end + return SavedSubsystem( + state_map, nothing, nothing, identity_partitions, nothing, nothing, nothing) + end + + if num_ts_params == 0 + return SavedSubsystem(state_map, nothing, nothing, (), nothing, nothing, nothing) + end + + identitypartitions = Set{Ttsidx}() + parammap = Dict() + timeseries_partition_templates = Dict() + TsavedParamIdx = ParameterTimeseriesIndex{Ttsidx, NTuple{2, Int}} + indexes_in_partition = Dict{Ttsidx, Vector{TParammapKeys}}() + for (tsidx, type_to_idxs) in ts_idx_to_type_to_param_idx + if ts_idx_to_count[tsidx] == length(all_ts_params[tsidx]) + push!(identitypartitions, tsidx) + continue + end + templates = VectorTemplate[] + for (type, idxs) in type_to_idxs + template = VectorTemplate(type, length(idxs)) + push!(templates, template) + for (i, idx) in enumerate(idxs) + pti = ParameterTimeseriesIndex(tsidx, (length(templates), i)) + parammap[idx] = pti + + buf = get!(() -> TsavedParamIdx[], indexes_in_partition, tsidx) + push!(buf, idx) + end + end + timeseries_partition_templates[tsidx] = templates + end + parammap = Dict{TParammapKeys, TsavedParamIdx}(parammap) + timeseries_partition_templates = Dict{Ttsidx, Vector{VectorTemplate}}(timeseries_partition_templates) + ts_idx_to_count = Dict{Ttsidx, Int}(ts_idx_to_count) + timeseries_idx_to_param_idx = Dict{TParammapKeys, TParamIdx}(timeseries_idx_to_param_idx) + return SavedSubsystem( + state_map, parammap, timeseries_idx_to_param_idx, identitypartitions, + timeseries_partition_templates, indexes_in_partition, ts_idx_to_count) +end + +""" + $(TYPEDEF) + +A combination of a `SavedSubsystem` and a fallback index provider. The provided fallback +is used as the `symbolic_container` for the `SavedSubsystemWithFallback`. Manually +implements `is_timeseries_parameter` and `timeseries_parameter_index` using the +`SavedSubsystem` to return the appropriate indexes for the subset of saved variables, +and `nothing`/`false` otherwise. + +Also implements `create_parameter_timeseries_collection`, `get_saveable_values` and +`with_updated_parameter_timeseries_values` to appropriately handled subsetted timeseries +parameters. +""" +struct SavedSubsystemWithFallback{S <: SavedSubsystem, T} + saved_subsystem::S + fallback::T +end + +function SymbolicIndexingInterface.symbolic_container(sswf::SavedSubsystemWithFallback) + sswf.fallback +end + +function SymbolicIndexingInterface.is_timeseries_parameter( + sswf::SavedSubsystemWithFallback, sym) + timeseries_parameter_index(sswf, sym) !== nothing +end + +function SymbolicIndexingInterface.timeseries_parameter_index( + sswf::SavedSubsystemWithFallback, sym) + ss = sswf.saved_subsystem + ss.timeseries_params_map === nothing && return nothing + if symbolic_type(sym) == NotSymbolic() + sym isa ParameterTimeseriesIndex || return nothing + sym.timeseries_idx in ss.identity_partitions && return sym + return get(ss.timeseries_params_map, sym, nothing) + end + v = timeseries_parameter_index(sswf.fallback, sym) + return timeseries_parameter_index(sswf, v) +end + +function create_parameter_timeseries_collection(sswf::SavedSubsystemWithFallback, ps, tspan) + original = create_parameter_timeseries_collection(sswf.fallback, ps, tspan) + ss = sswf.saved_subsystem + if original === nothing + return nothing + end + newcollection = map(enumerate(original)) do (i, buffer) + i in ss.identity_partitions && return buffer + ss.partition_count === nothing && return buffer + cnt = get(ss.partition_count, i, 0) + cnt == 0 && return buffer + + return as_diffeq_array(ss.timeseries_partition_templates[i], buffer.t) + end + + return ParameterTimeseriesCollection(newcollection, parameter_values(original)) +end + +function get_saveable_values(sswf::SavedSubsystemWithFallback, ps, tsidx) + ss = sswf.saved_subsystem + original = get_saveable_values(sswf.fallback, ps, tsidx) + tsidx in ss.identity_partitions && return original + ss.partition_count === nothing && return nothing + cnt = get(ss.partition_count, tsidx, 0) + cnt == 0 && return nothing + + toaw = TupleOfArraysWrapper(ss.timeseries_partition_templates[tsidx]) + for idx in ss.indexes_in_partition[tsidx] + toaw[ss.timeseries_params_map[idx].parameter_idx] = original[idx.parameter_idx] + end + return toaw +end + +function SymbolicIndexingInterface.with_updated_parameter_timeseries_values( + sswf::SavedSubsystemWithFallback, ps, args...) + ss = sswf.saved_subsystem + for (tsidx, val) in args + if tsidx in ss.identity_partitions + ps = with_updated_parameter_timeseries_values(sswf.fallback, ps, tsidx => val) + continue + end + ss.partition_count === nothing && continue + cnt = get(ss.partition_count, tsidx, 0) + cnt == 0 && continue + + # now we know val isa TupleOfArraysWrapper + for idx in ss.indexes_in_partition[tsidx] + set_parameter!(ps, val[ss.timeseries_params_map[idx].parameter_idx], + ss.timeseries_idx_to_param_idx[idx]) + end + end + + return ps +end diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index 2d9ca7634..9e6fba1e2 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -59,6 +59,53 @@ SymbolicIndexingInterface.is_time_dependent(::AbstractNoTimeSolution) = false SymbolicIndexingInterface.constant_structure(::AbstractSolution) = true SymbolicIndexingInterface.state_values(A::AbstractNoTimeSolution) = A.u +function get_saved_subsystem(sol::T) where {T <: AbstractTimeseriesSolution} + hasfield(T, :saved_subsystem) ? sol.saved_subsystem : nothing +end + +for fn in [is_timeseries_parameter, timeseries_parameter_index, + with_updated_parameter_timeseries_values, get_saveable_values] + fname = nameof(fn) + mod = parentmodule(fn) + + @eval function $(mod).$(fname)(sol::AbstractTimeseriesSolution, args...) + ss = get_saved_subsystem(sol) + if ss === nothing + $(fn)(symbolic_container(sol), args...) + else + $(fn)(SavedSubsystemWithFallback(ss, symbolic_container(sol)), args...) + end + end +end + +function SymbolicIndexingInterface.state_values(sol::AbstractTimeseriesSolution, i) + ss = get_saved_subsystem(sol) + ss === nothing && return sol.u[i] + + original = state_values(sol.prob) + saved = sol.u[i] + if !(saved isa AbstractArray) + saved = [saved] + end + idxs = similar(saved, eltype(keys(ss.state_map))) + for (k, v) in ss.state_map + idxs[v] = k + end + replaced = remake_buffer(sol, original, idxs, saved) + return replaced +end + +function SymbolicIndexingInterface.state_values(sol::AbstractTimeseriesSolution) + ss = get_saved_subsystem(sol) + ss === nothing && return sol.u + return map(Base.Fix1(state_values, sol), eachindex(sol.u)) +end + +# Ambiguity resolution +function SymbolicIndexingInterface.state_values(sol::AbstractTimeseriesSolution, ::Colon) + state_values(sol) +end + Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, ::Colon) return A.u[:] end diff --git a/test/downstream/solution_interface.jl b/test/downstream/solution_interface.jl index 709901c85..4cbeb3173 100644 --- a/test/downstream/solution_interface.jl +++ b/test/downstream/solution_interface.jl @@ -1,4 +1,6 @@ using ModelingToolkit, OrdinaryDiffEq, RecursiveArrayTools, StochasticDiffEq, Test +using StochasticDiffEq +using SymbolicIndexingInterface using ModelingToolkit: t_nounits as t, D_nounits as D using Plots: Plots, plot @@ -170,3 +172,155 @@ sol10 = sol(0.1, idxs = 2) end end end + +@testset "Saved subsystem" begin + @testset "Purely continuous ODE/DAE/SDE-solutions" begin + @variables x(t) y(t) + @parameters p + @mtkbuild sys = ODESystem([D(x) ~ x + p * y, D(y) ~ 2p + x^2], t) + @test length(unknowns(sys)) == 2 + xidx = variable_index(sys, x) + prob = ODEProblem(sys, [x => 1.0, y => 1.0], (0.0, 5.0), [p => 0.5]) + + @test SciMLBase.SavedSubsystem(sys, prob.p, []) === + SciMLBase.SavedSubsystem(sys, prob.p, nothing) === nothing + @test SciMLBase.SavedSubsystem(sys, prob.p, [x, y]) === nothing + @test begin + ss1 = SciMLBase.SavedSubsystem(sys, prob.p, [x]) + ss2 = SciMLBase.SavedSubsystem(sys, prob.p, [xidx]) + ss1.state_map == ss2.state_map + end + + ode_sol = solve(prob, Tsit5(); save_idxs = xidx) + subsys = SciMLBase.SavedSubsystem(sys, prob.p, [xidx]) + # FIXME: hack for save_idxs + SciMLBase.@reset ode_sol.saved_subsystem = subsys + + @mtkbuild sys = ODESystem([D(x) ~ x + p * y, 1 ~ sin(y) + cos(x)], t) + xidx = variable_index(sys, x) + prob = DAEProblem(sys, [D(x) => x + p * y, D(y) => 1 / sqrt(1 - (1 - cos(x))^2)], + [x => 1.0, y => asin(1 - cos(x))], (0.0, 1.0), [p => 2.0]) + dae_sol = solve(prob, DFBDF(); save_idxs = xidx) + subsys = SciMLBase.SavedSubsystem(sys, prob.p, [xidx]) + # FIXME: hack for save_idxs + SciMLBase.@reset dae_sol.saved_subsystem = subsys + + @brownian a b + @mtkbuild sys = System([D(x) ~ x + p * y + x * a, D(y) ~ 2p + x^2 + y * b], t) + xidx = variable_index(sys, x) + prob = SDEProblem(sys, [x => 1.0, y => 2.0], (0.0, 1.0), [p => 2.0]) + sde_sol = solve(prob, SOSRI(); save_idxs = xidx) + subsys = SciMLBase.SavedSubsystem(sys, prob.p, [xidx]) + # FIXME: hack for save_idxs + SciMLBase.@reset sde_sol.saved_subsystem = subsys + + for sol in [ode_sol, dae_sol, sde_sol] + prob = sol.prob + subsys = sol.saved_subsystem + xvals = sol[x] + @test sol[x] == xvals + @test is_parameter(sol, p) + @test parameter_index(sol, p) == parameter_index(sys, p) + @test isequal(only(parameter_symbols(sol)), p) + @test is_independent_variable(sol, t) + + tmp = copy(prob.u0) + tmp[xidx] = xvals[2] + @test state_values(sol, 2) == tmp + @test state_values(sol) == [state_values(sol, i) for i in 1:length(sol)] + end + end + + @testset "ODE with callbacks" begin + @variables x(t) y(t) + @parameters p q(t) r(t) s(t) u(t) + evs = [0.1 => [q ~ q + 1, s ~ s - 1], 0.3 => [r ~ 2r, u ~ u / 2]] + @mtkbuild sys = ODESystem([D(x) ~ x + p * y, D(y) ~ 2p + x], t, [x, y], + [p, q, r, s, u], discrete_events = evs) + @test length(unknowns(sys)) == 2 + @test length(parameters(sys)) == 5 + @test is_timeseries_parameter(sys, q) + @test is_timeseries_parameter(sys, r) + xidx = variable_index(sys, x) + qidx = parameter_index(sys, q) + qpidx = timeseries_parameter_index(sys, q) + ridx = parameter_index(sys, r) + rpidx = timeseries_parameter_index(sys, r) + sidx = parameter_index(sys, s) + uidx = parameter_index(sys, u) + + prob = ODEProblem(sys, [x => 1.0, y => 1.0], (0.0, 5.0), + [p => 0.5, q => 0.0, r => 1.0, s => 10.0, u => 4096.0]) + + @test SciMLBase.SavedSubsystem(sys, prob.p, [x, y, q, r, s, u]) === nothing + + sol = solve(prob; save_idxs = xidx) + xvals = sol[x] + subsys = SciMLBase.SavedSubsystem(sys, prob.p, [x, q, r]) + qvals = sol.ps[q] + rvals = sol.ps[r] + # FIXME: hack for save_idxs + SciMLBase.@reset sol.saved_subsystem = subsys + discq = DiffEqArray(SciMLBase.TupleOfArraysWrapper.(tuple.(Base.vect.(qvals))), + sol.discretes[qpidx.timeseries_idx].t, (1, 1)) + discr = DiffEqArray(SciMLBase.TupleOfArraysWrapper.(tuple.(Base.vect.(rvals))), + sol.discretes[rpidx.timeseries_idx].t, (1, 1)) + SciMLBase.@reset sol.discretes.collection[qpidx.timeseries_idx] = discq + SciMLBase.@reset sol.discretes.collection[rpidx.timeseries_idx] = discr + + @test sol[x] == xvals + + @test all(Base.Fix1(is_parameter, sol), [p, q, r, s, u]) + @test all(Base.Fix1(is_timeseries_parameter, sol), [q, r]) + @test all(!Base.Fix1(is_timeseries_parameter, sol), [s, u]) + @test timeseries_parameter_index(sol, q) == + ParameterTimeseriesIndex(qpidx.timeseries_idx, (1, 1)) + @test timeseries_parameter_index(sol, r) == + ParameterTimeseriesIndex(rpidx.timeseries_idx, (1, 1)) + @test sol[q] == qvals + @test sol[r] == rvals + end + + @testset "SavedSubsystemWithFallback" begin + @variables x(t) y(t) + @parameters p q(t) r(t) s(t) u(t) + evs = [0.1 => [q ~ q + 1, s ~ s - 1], 0.3 => [r ~ 2r, u ~ u / 2]] + @mtkbuild sys = ODESystem([D(x) ~ x + p * y, D(y) ~ 2p + x^2], t, [x, y], + [p, q, r, s, u], discrete_events = evs) + prob = ODEProblem(sys, [x => 1.0, y => 1.0], (0.0, 5.0), + [p => 0.5, q => 0.0, r => 1.0, s => 10.0, u => 4096.0]) + ss = SciMLBase.SavedSubsystem(sys, prob.p, [x, q, s, r]) + sswf = SciMLBase.SavedSubsystemWithFallback(ss, sys) + xidx = variable_index(sys, x) + qidx = parameter_index(sys, q) + qpidx = timeseries_parameter_index(sys, q) + ridx = parameter_index(sys, r) + rpidx = timeseries_parameter_index(sys, r) + sidx = parameter_index(sys, s) + uidx = parameter_index(sys, u) + @test qpidx.timeseries_idx in ss.identity_partitions + @test !(rpidx.timeseries_idx in ss.identity_partitions) + @test timeseries_parameter_index(sswf, q) == timeseries_parameter_index(sys, q) + @test timeseries_parameter_index(sswf, s) == timeseries_parameter_index(sys, s) + + ptc = SciMLBase.create_parameter_timeseries_collection(sswf, prob.p, prob.tspan) + origptc = SciMLBase.create_parameter_timeseries_collection(sys, prob.p, prob.tspan) + + @test ptc[qpidx.timeseries_idx] == origptc[qpidx.timeseries_idx] + @test eltype(ptc[rpidx.timeseries_idx].u) <: SciMLBase.TupleOfArraysWrapper + vals = SciMLBase.get_saveable_values(sswf, prob.p, qpidx.timeseries_idx) + origvals = SciMLBase.get_saveable_values(sys, prob.p, qpidx.timeseries_idx) + @test typeof(vals) == typeof(origvals) + @test !(vals isa SciMLBase.TupleOfArraysWrapper) + + vals = SciMLBase.get_saveable_values(sswf, prob.p, rpidx.timeseries_idx) + @test vals isa SciMLBase.TupleOfArraysWrapper + @test vals.x isa Tuple{Vector{Float64}} + @test vals.x[1][1] == prob.ps[r] + + vals[(1, 1)] = 2prob.ps[r] + newp = with_updated_parameter_timeseries_values(sswf, + parameter_values(ptc), rpidx.timeseries_idx => vals) + @test newp[ridx] == 2prob.ps[r] + end +end