Skip to content

Commit

Permalink
feat: support symbolic indexing of a subset of the system
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Oct 14, 2024
1 parent 17f4548 commit d47c3bf
Show file tree
Hide file tree
Showing 4 changed files with 529 additions and 12 deletions.
1 change: 1 addition & 0 deletions src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,7 @@ include("problems/problem_interface.jl")
include("problems/optimization_problems.jl")

include("clock.jl")
include("solutions/save_idxs.jl")
include("solutions/basic_solutions.jl")
include("solutions/nonlinear_solutions.jl")
include("solutions/ode_solutions.jl")
Expand Down
93 changes: 81 additions & 12 deletions src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,12 @@ https://docs.sciml.ai/DiffEqDocs/stable/basics/solution/
successfully, whether it terminated early due to a user-defined callback, or whether it
exited due to an error. For more details, see
[the return code documentation](https://docs.sciml.ai/SciMLBase/stable/interfaces/Solutions/#retcodes).
- `saved_subsystem`: a [`SavedSubsystem`](@ref) representing the subset of variables saved
in this solution, or `nothing` if all variables are saved. Here "variables" refers to
both continuous-time state variables and timeseries parameters.
"""
struct ODESolution{T, N, uType, uType2, DType, tType, rateType, discType, P, A, IType, S,
AC <: Union{Nothing, Vector{Int}}, R, O} <:
AC <: Union{Nothing, Vector{Int}}, R, O, V} <:
AbstractODESolution{T, N, uType}
u::uType
u_analytic::uType2
Expand All @@ -124,6 +127,7 @@ struct ODESolution{T, N, uType, uType2, DType, tType, rateType, discType, P, A,
retcode::ReturnCode.T
resid::R
original::O
saved_subsystem::V
end

function ConstructionBase.constructorof(::Type{O}) where {T, N, O <: ODESolution{T, N}}
Expand All @@ -137,7 +141,7 @@ function ConstructionBase.setproperties(sol::ODESolution, patch::NamedTuple)
patch = merge(getproperties(sol), patch)
return ODESolution{T, N}(patch.u, patch.u_analytic, patch.errors, patch.t, patch.k,
patch.discretes, patch.prob, patch.alg, patch.interp, patch.dense, patch.tslocation, patch.stats,
patch.alg_choice, patch.retcode, patch.resid, patch.original)
patch.alg_choice, patch.retcode, patch.resid, patch.original, patch.saved_subsystem)
end

Base.@propagate_inbounds function Base.getproperty(x::AbstractODESolution, s::Symbol)
Expand All @@ -154,12 +158,12 @@ end
function ODESolution{T, N}(
u, u_analytic, errors, t, k, discretes, prob, alg, interp, dense,
tslocation, stats, alg_choice, retcode, resid = nothing,
original = nothing) where {T, N}
original = nothing, saved_subsystem = nothing) where {T, N}
return ODESolution{T, N, typeof(u), typeof(u_analytic), typeof(errors), typeof(t),
typeof(k), typeof(discretes), typeof(prob), typeof(alg), typeof(interp),
typeof(stats), typeof(alg_choice), typeof(resid),
typeof(original)}(u, u_analytic, errors, t, k, discretes, prob, alg, interp,
dense, tslocation, stats, alg_choice, retcode, resid, original)
typeof(stats), typeof(alg_choice), typeof(resid), typeof(original),
typeof(saved_subsystem)}(u, u_analytic, errors, t, k, discretes, prob, alg, interp,
dense, tslocation, stats, alg_choice, retcode, resid, original, saved_subsystem)
end

error_if_observed_derivative(_, _, ::Type{Val{0}}) = nothing
Expand All @@ -180,6 +184,53 @@ function SymbolicIndexingInterface.is_parameter_timeseries(::Type{S}) where {
Timeseries()
end

const SolutionWithSavedSubsystem = ODESolution{T1,
T2,
T3,
T4,
T5,
T6,
T7,
T8,
T9,
T10,
T11,
T12,
T13,
T14,
T15,
T16} where {
T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16 <: SavedSubsystem}

for method in [is_timeseries_parameter, timeseries_parameter_index,
with_updated_parameter_timeseries_values, get_saveable_values]
fname = nameof(method)
mod = parentmodule(method)
@eval function $(mod).$(fname)(sol::SolutionWithSavedSubsystem, args...)
$(method)(SavedSubsystemWithFallback(sol.saved_subsystem, symbolic_container(sol)),
args...)
end
end

function SymbolicIndexingInterface.state_values(sol::SolutionWithSavedSubsystem, i)
original = state_values(sol.prob)
saved = sol.u[i]
if !(saved isa AbstractArray)
saved = [saved]
end
ss = sol.saved_subsystem
idxs = similar(saved, eltype(keys(ss.state_map)))
for (k, v) in ss.state_map
idxs[v] = k
end
replaced = remake_buffer(sol, original, idxs, saved)
return replaced
end

function SymbolicIndexingInterface.state_values(sol::SolutionWithSavedSubsystem)
return map(Base.Fix1(state_values, sol), eachindex(sol.u))
end

function _hold_discrete(disc_u, disc_t, t::Number)
idx = searchsortedlast(disc_t, t)
if idx == firstindex(disc_t) - 1
Expand Down Expand Up @@ -409,15 +460,25 @@ const PeriodicDiffEqArray = DiffEqArray{T, N, A, B} where {T, N, A, B <: Abstrac
# public API, used by MTK
"""
get_saveable_values(sys, ps, timeseries_idx)
Return the values to be saved in parameter object `ps` for timeseries index `timeseries_idx`. Called by
`save_discretes!`. If this returns `nothing`, `save_discretes!` will not save anything.
"""
function get_saveable_values(sys, ps, timeseries_idx)
return get_saveable_values(symbolic_container(sys), ps, timeseries_idx)
end

"""
save_discretes!(integ::DEIntegrator, timeseries_idx)
Save the parameter timeseries with index `timeseries_idx`. Calls `get_saveable_values` to
get the values to save. If it returns `nothing`, then the save does not happen.
"""
function save_discretes!(integ::DEIntegrator, timeseries_idx)
save_discretes!(integ.sol, current_time(integ),
get_saveable_values(integ, parameter_values(integ), timeseries_idx),
timeseries_idx)
inner_sol = get_sol(integ)
vals = get_saveable_values(inner_sol, parameter_values(integ), timeseries_idx)
vals === nothing && return
save_discretes!(integ.sol, current_time(integ), vals, timeseries_idx)
end

save_discretes!(args...) = nothing
Expand Down Expand Up @@ -451,6 +512,7 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
interp = LinearInterpolation(t, u),
retcode = ReturnCode.Default, destats = missing, stats = nothing,
resid = nothing, original = nothing,
saved_subsystem = nothing,
kwargs...)
T = eltype(eltype(u))

Expand Down Expand Up @@ -482,7 +544,12 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},

ps = parameter_values(prob)
if has_sys(prob.f)
discretes = create_parameter_timeseries_collection(prob.f.sys, ps, prob.tspan)
sswf = if saved_subsystem === nothing
prob.f.sys
else
SavedSubsystemWithFallback(saved_subsystem, prob.f.sys)
end
discretes = create_parameter_timeseries_collection(sswf, ps, prob.tspan)
else
discretes = nothing
end
Expand All @@ -503,7 +570,8 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
alg_choice,
retcode,
resid,
original)
original,
saved_subsystem)
if calculate_error
calculate_solution_errors!(sol; timeseries_errors = timeseries_errors,
dense_errors = dense_errors)
Expand All @@ -524,7 +592,8 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
alg_choice,
retcode,
resid,
original)
original,
saved_subsystem)
end
end

Expand Down
Loading

0 comments on commit d47c3bf

Please sign in to comment.