Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support symbolic indexing of a subset of the system #809

Merged
merged 6 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,7 @@ include("problems/problem_interface.jl")
include("problems/optimization_problems.jl")

include("clock.jl")
include("solutions/save_idxs.jl")
include("solutions/basic_solutions.jl")
include("solutions/nonlinear_solutions.jl")
include("solutions/ode_solutions.jl")
Expand Down
17 changes: 10 additions & 7 deletions src/integrator_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -331,19 +331,22 @@ Otherwise the integrator is allowed to skip recalculating the interpolation.

# Arguments

- `continuous_modification`: determines whether the modification is due to a continuous change (continuous callback)
or a discrete callback. For a continuous change, this can include a change to time which requires a re-evaluation
of the interpolations.
- `callback_initializealg`: the initialization algorithm provided by the callback. For DAEs, this is the choice for the
initialization that is done post callback. The default value of `nothing` means that the initialization choice
used for the DAE should be performed post-callback.
- `continuous_modification`: determines whether the modification is due to a continuous change (continuous callback)
or a discrete callback. For a continuous change, this can include a change to time which requires a re-evaluation
of the interpolations.
- `callback_initializealg`: the initialization algorithm provided by the callback. For DAEs, this is the choice for the
initialization that is done post callback. The default value of `nothing` means that the initialization choice
used for the DAE should be performed post-callback.
"""
function reeval_internals_due_to_modification!(
integrator::DEIntegrator, continuous_modification;
callback_initializealg = nothing)
reeval_internals_due_to_modification!(integrator::DEIntegrator)
end
reeval_internals_due_to_modification!(integrator::DEIntegrator; callback_initializealg = nothing) = nothing
function reeval_internals_due_to_modification!(
integrator::DEIntegrator; callback_initializealg = nothing)
nothing
end

"""
set_t!(integrator::DEIntegrator, t)
Expand Down
4 changes: 2 additions & 2 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2693,8 +2693,8 @@ function SplitFunction{iip, specialize}(f1, f2;
f1.jac_prototype :
nothing,
W_prototype = __has_W_prototype(f1) ?
f1.W_prototype :
nothing,
f1.W_prototype :
nothing,
sparsity = __has_sparsity(f1) ? f1.sparsity :
jac_prototype,
Wfact = __has_Wfact(f1) ? f1.Wfact : nothing,
Expand Down
115 changes: 46 additions & 69 deletions src/solutions/dae_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ https://docs.sciml.ai/DiffEqDocs/stable/basics/solution/
exited due to an error. For more details, see
[the return code documentation](https://docs.sciml.ai/SciMLBase/stable/interfaces/Solutions/#retcodes).
"""
struct DAESolution{T, N, uType, duType, uType2, DType, tType, P, A, ID, S, rateType} <:
struct DAESolution{T, N, uType, duType, uType2, DType, tType, P, A, ID, S, rateType, V} <:
AbstractDAESolution{T, N, uType}
u::uType
du::duType
Expand All @@ -42,6 +42,31 @@ struct DAESolution{T, N, uType, duType, uType2, DType, tType, P, A, ID, S, rateT
tslocation::Int
stats::S
retcode::ReturnCode.T
saved_subsystem::V
end

function DAESolution{T, N}(u, du, u_analytic, errors, t, k, prob, alg, interp, dense,
tslocation, stats, retcode, saved_subsystem) where {T, N}
return DAESolution{T, N, typeof(u), typeof(du), typeof(u_analytic), typeof(errors),
typeof(t), typeof(prob), typeof(alg), typeof(interp), typeof(stats), typeof(k),
typeof(saved_subsystem)}(
u, du, u_analytic, errors, t, k, prob, alg, interp, dense, tslocation, stats,
retcode, saved_subsystem
)
end

function ConstructionBase.constructorof(::Type{O}) where {T, N, O <: DAESolution{T, N}}
DAESolution{T, N}
end

function ConstructionBase.setproperties(sol::DAESolution, patch::NamedTuple)
u = get(patch, :u, sol.u)
N = u === nothing ? 2 : ndims(eltype(u)) + 1
T = eltype(eltype(u))
patch = merge(getproperties(sol), patch)
return DAESolution{T, N}(patch.u, patch.du, patch.u_analytic, patch.errors, patch.t,
patch.k, patch.prob, patch.alg, patch.interp, patch.dense, patch.tslocation,
patch.stats, patch.retcode, patch.saved_subsystem)
end

Base.@propagate_inbounds function Base.getproperty(x::AbstractDAESolution, s::Symbol)
Expand All @@ -65,13 +90,14 @@ function build_solution(prob::AbstractDAEProblem, alg, t, u, du = nothing;
retcode = ReturnCode.Default,
destats = missing,
stats = nothing,
saved_subsystem = nothing,
kwargs...)
T = eltype(eltype(u))

if prob.u0 === nothing
N = 2
else
N = length((size(prob.u0)..., length(u)))
N = ndims(eltype(u)) + 1
end

if !ismissing(destats)
Expand All @@ -88,7 +114,8 @@ function build_solution(prob::AbstractDAEProblem, alg, t, u, du = nothing;
errors = Dict{Symbol, real(eltype(prob.u0))}()

sol = DAESolution{T, N, typeof(u), typeof(du), typeof(u_analytic), typeof(errors),
typeof(t), typeof(prob), typeof(alg), typeof(interp), typeof(stats), typeof(k)}(
typeof(t), typeof(prob), typeof(alg), typeof(interp), typeof(stats), typeof(k),
typeof(saved_subsystem)}(
u,
du,
u_analytic,
Expand All @@ -101,7 +128,8 @@ function build_solution(prob::AbstractDAEProblem, alg, t, u, du = nothing;
dense,
0,
stats,
retcode)
retcode,
saved_subsystem)

if calculate_error
calculate_solution_errors!(sol; timeseries_errors = timeseries_errors,
Expand All @@ -110,15 +138,17 @@ function build_solution(prob::AbstractDAEProblem, alg, t, u, du = nothing;
sol
else
DAESolution{T, N, typeof(u), typeof(du), Nothing, Nothing, typeof(t),
typeof(prob), typeof(alg), typeof(interp), typeof(stats), typeof(k)}(
typeof(prob), typeof(alg), typeof(interp), typeof(stats), typeof(k),
typeof(saved_subsystem)}(
u, du,
nothing,
nothing, t, k,
prob, alg,
interp,
dense, 0,
stats,
retcode)
retcode,
saved_subsystem)
end
end

Expand Down Expand Up @@ -161,76 +191,23 @@ function calculate_solution_errors!(sol::AbstractDAESolution;
end

function build_solution(sol::AbstractDAESolution{T, N}, u_analytic, errors) where {T, N}
DAESolution{T, N, typeof(sol.u), typeof(sol.du), typeof(u_analytic), typeof(errors),
typeof(sol.t), typeof(sol.prob), typeof(sol.alg), typeof(sol.interp),
typeof(sol.stats), typeof(sol.k)}(sol.u,
sol.du,
u_analytic,
errors,
sol.t,
sol.k,
sol.prob,
sol.alg,
sol.interp,
sol.dense,
sol.tslocation,
sol.stats,
sol.retcode)
@reset sol.u_analytic = u_analytic
return @set sol.errors = errors
end

function solution_new_retcode(sol::AbstractDAESolution{T, N}, retcode) where {T, N}
DAESolution{T, N, typeof(sol.u), typeof(sol.du), typeof(sol.u_analytic),
typeof(sol.errors), typeof(sol.t), typeof(sol.prob), typeof(sol.alg),
typeof(sol.interp), typeof(sol.stats), typeof(sol.k)}(sol.u,
sol.du,
sol.u_analytic,
sol.errors,
sol.t,
sol.k,
sol.prob,
sol.alg,
sol.interp,
sol.dense,
sol.tslocation,
sol.stats,
retcode)
return @set sol.retcode = retcode
end

function solution_new_tslocation(sol::AbstractDAESolution{T, N}, tslocation) where {T, N}
DAESolution{T, N, typeof(sol.u), typeof(sol.du), typeof(sol.u_analytic),
typeof(sol.errors), typeof(sol.t), typeof(sol.prob), typeof(sol.alg),
typeof(sol.interp), typeof(sol.stats), typeof(k)}(sol.u,
sol.du,
sol.u_analytic,
sol.errors,
sol.t,
sol.k,
sol.prob,
sol.alg,
sol.interp,
sol.dense,
tslocation,
sol.stats,
sol.retcode)
return @set sol.tslocation = tslocation
end

function solution_slice(sol::AbstractDAESolution{T, N}, I) where {T, N}
DAESolution{T, N, typeof(sol.u), typeof(sol.du), typeof(sol.u_analytic),
typeof(sol.errors), typeof(sol.t), typeof(sol.prob), typeof(sol.alg),
typeof(sol.interp), typeof(sol.stats), typeof(sol.k)}(sol.u[I],
sol.du[I],
sol.u_analytic ===
nothing ?
nothing :
sol.u_analytic[I],
sol.errors,
sol.t[I],
sol.k[I],
sol.prob,
sol.alg,
sol.interp,
false,
sol.tslocation,
sol.stats,
sol.retcode)
@reset sol.u = sol.u[I]
@reset sol.du = sol.du[I]
@reset sol.u_analytic = sol.u_analytic === nothing ? nothing : sol.u_analytic[I]
@reset sol.t = sol.t[I]
@reset sol.k = sol.dense ? sol.k[I] : sol.k
return @set sol.dense = false
end
48 changes: 35 additions & 13 deletions src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,12 @@ https://docs.sciml.ai/DiffEqDocs/stable/basics/solution/
successfully, whether it terminated early due to a user-defined callback, or whether it
exited due to an error. For more details, see
[the return code documentation](https://docs.sciml.ai/SciMLBase/stable/interfaces/Solutions/#retcodes).
- `saved_subsystem`: a [`SavedSubsystem`](@ref) representing the subset of variables saved
in this solution, or `nothing` if all variables are saved. Here "variables" refers to
both continuous-time state variables and timeseries parameters.
"""
struct ODESolution{T, N, uType, uType2, DType, tType, rateType, discType, P, A, IType, S,
AC <: Union{Nothing, Vector{Int}}, R, O} <:
AC <: Union{Nothing, Vector{Int}}, R, O, V} <:
AbstractODESolution{T, N, uType}
u::uType
u_analytic::uType2
Expand All @@ -124,6 +127,7 @@ struct ODESolution{T, N, uType, uType2, DType, tType, rateType, discType, P, A,
retcode::ReturnCode.T
resid::R
original::O
saved_subsystem::V
end

function ConstructionBase.constructorof(::Type{O}) where {T, N, O <: ODESolution{T, N}}
Expand All @@ -137,7 +141,7 @@ function ConstructionBase.setproperties(sol::ODESolution, patch::NamedTuple)
patch = merge(getproperties(sol), patch)
return ODESolution{T, N}(patch.u, patch.u_analytic, patch.errors, patch.t, patch.k,
patch.discretes, patch.prob, patch.alg, patch.interp, patch.dense, patch.tslocation, patch.stats,
patch.alg_choice, patch.retcode, patch.resid, patch.original)
patch.alg_choice, patch.retcode, patch.resid, patch.original, patch.saved_subsystem)
end

Base.@propagate_inbounds function Base.getproperty(x::AbstractODESolution, s::Symbol)
Expand All @@ -154,12 +158,12 @@ end
function ODESolution{T, N}(
u, u_analytic, errors, t, k, discretes, prob, alg, interp, dense,
tslocation, stats, alg_choice, retcode, resid = nothing,
original = nothing) where {T, N}
original = nothing, saved_subsystem = nothing) where {T, N}
return ODESolution{T, N, typeof(u), typeof(u_analytic), typeof(errors), typeof(t),
typeof(k), typeof(discretes), typeof(prob), typeof(alg), typeof(interp),
typeof(stats), typeof(alg_choice), typeof(resid),
typeof(original)}(u, u_analytic, errors, t, k, discretes, prob, alg, interp,
dense, tslocation, stats, alg_choice, retcode, resid, original)
typeof(stats), typeof(alg_choice), typeof(resid), typeof(original),
typeof(saved_subsystem)}(u, u_analytic, errors, t, k, discretes, prob, alg, interp,
dense, tslocation, stats, alg_choice, retcode, resid, original, saved_subsystem)
end

error_if_observed_derivative(_, _, ::Type{Val{0}}) = nothing
Expand Down Expand Up @@ -409,15 +413,25 @@ const PeriodicDiffEqArray = DiffEqArray{T, N, A, B} where {T, N, A, B <: Abstrac
# public API, used by MTK
"""
get_saveable_values(sys, ps, timeseries_idx)

Return the values to be saved in parameter object `ps` for timeseries index `timeseries_idx`. Called by
`save_discretes!`. If this returns `nothing`, `save_discretes!` will not save anything.
"""
function get_saveable_values(sys, ps, timeseries_idx)
return get_saveable_values(symbolic_container(sys), ps, timeseries_idx)
end

"""
save_discretes!(integ::DEIntegrator, timeseries_idx)

Save the parameter timeseries with index `timeseries_idx`. Calls `get_saveable_values` to
get the values to save. If it returns `nothing`, then the save does not happen.
"""
function save_discretes!(integ::DEIntegrator, timeseries_idx)
save_discretes!(integ.sol, current_time(integ),
get_saveable_values(integ, parameter_values(integ), timeseries_idx),
timeseries_idx)
inner_sol = get_sol(integ)
vals = get_saveable_values(inner_sol, parameter_values(integ), timeseries_idx)
vals === nothing && return
save_discretes!(integ.sol, current_time(integ), vals, timeseries_idx)
end

save_discretes!(args...) = nothing
Expand Down Expand Up @@ -451,6 +465,7 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
interp = LinearInterpolation(t, u),
retcode = ReturnCode.Default, destats = missing, stats = nothing,
resid = nothing, original = nothing,
saved_subsystem = nothing,
kwargs...)
T = eltype(eltype(u))

Expand Down Expand Up @@ -482,7 +497,12 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},

ps = parameter_values(prob)
if has_sys(prob.f)
discretes = create_parameter_timeseries_collection(prob.f.sys, ps, prob.tspan)
sswf = if saved_subsystem === nothing
prob.f.sys
else
SavedSubsystemWithFallback(saved_subsystem, prob.f.sys)
end
discretes = create_parameter_timeseries_collection(sswf, ps, prob.tspan)
else
discretes = nothing
end
Expand All @@ -503,7 +523,8 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
alg_choice,
retcode,
resid,
original)
original,
saved_subsystem)
if calculate_error
calculate_solution_errors!(sol; timeseries_errors = timeseries_errors,
dense_errors = dense_errors)
Expand All @@ -524,7 +545,8 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
alg_choice,
retcode,
resid,
original)
original,
saved_subsystem)
end
end

Expand Down Expand Up @@ -593,7 +615,7 @@ function solution_slice(sol::ODESolution{T, N}, I) where {T, N}
@reset sol.u_analytic = sol.u_analytic === nothing ? nothing : sol.u_analytic[I]
@reset sol.t = sol.t[I]
@reset sol.k = sol.dense ? sol.k[I] : sol.k
return @set sol.alg = false
return @set sol.dense = false
end

mask_discretes(::Nothing, _, _...) = nothing
Expand Down
Loading
Loading