Skip to content

Commit

Permalink
Merge pull request #847 from AayushSabharwal/as/get-save-idxs
Browse files Browse the repository at this point in the history
feat: add `get_save_idxs_and_saved_subsystem`
  • Loading branch information
ChrisRackauckas authored Nov 6, 2024
2 parents 7b6c7ef + b592989 commit 0d263e5
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 0 deletions.
37 changes: 37 additions & 0 deletions src/solutions/save_idxs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -347,3 +347,40 @@ function SymbolicIndexingInterface.with_updated_parameter_timeseries_values(

return ps
end

"""
$(TYPEDSIGNATURES)
Given a SciMLProblem `prob` and (possibly symbolic) `save_idxs`, return the `save_idxs`
corresponding to the state variables and a `SavedSubsystem` to pass to `build_solution`.
The second return value (corresponding to the `SavedSubsystem`) may be `nothing` in case
one is not required. `save_idxs` may be a scalar or `nothing`.
"""
function get_save_idxs_and_saved_subsystem(prob, save_idxs)
if save_idxs === nothing
saved_subsystem = nothing
else
if !(save_idxs isa AbstractArray) || symbolic_type(save_idxs) != NotSymbolic()
_save_idxs = [save_idxs]
else
_save_idxs = save_idxs
end
saved_subsystem = SavedSubsystem(prob, parameter_values(prob), _save_idxs)
if saved_subsystem !== nothing
_save_idxs = get_saved_state_idxs(saved_subsystem)
if isempty(_save_idxs)
# no states to save
save_idxs = Int[]
elseif !(save_idxs isa AbstractArray) || symbolic_type(save_idxs) != NotSymbolic()
# only a single state to save, and save it as a scalar timeseries instead of
# single-element array
save_idxs = only(_save_idxs)
else
save_idxs = _save_idxs
end
end
end

return save_idxs, saved_subsystem
end
31 changes: 31 additions & 0 deletions test/downstream/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -327,4 +327,35 @@ end
parameter_values(ptc), rpidx.timeseries_idx => vals)
@test newp[ridx] == 2prob.ps[r]
end

@testset "get_save_idxs_and_saved_subsystem" begin
@variables x(t) y(t)
@parameters p q(t) r(t) s(t) u(t)
evs = [0.1 => [q ~ q + 1, s ~ s - 1], 0.3 => [r ~ 2r, u ~ u / 2]]
@mtkbuild sys = ODESystem([D(x) ~ x + p * y, D(y) ~ 2p + x^2], t, [x, y],
[p, q, r, s, u], discrete_events = evs)
prob = ODEProblem(sys, [x => 1.0, y => 1.0], (0.0, 5.0),
[p => 0.5, q => 0.0, r => 1.0, s => 10.0, u => 4096.0])

_idxs, _ss = SciMLBase.get_save_idxs_and_saved_subsystem(prob, nothing)
@test _idxs === _ss === nothing
_idxs, _ss = SciMLBase.get_save_idxs_and_saved_subsystem(prob, 1)
@test _idxs == 1
@test _ss isa SciMLBase.SavedSubsystem
_idxs, _ss = SciMLBase.get_save_idxs_and_saved_subsystem(prob, [1])
@test _idxs == [1]
@test _ss isa SciMLBase.SavedSubsystem
_idxs, _ss = SciMLBase.get_save_idxs_and_saved_subsystem(prob, x)
@test _idxs == 1
@test _ss isa SciMLBase.SavedSubsystem
_idxs, _ss = SciMLBase.get_save_idxs_and_saved_subsystem(prob, [x])
@test _idxs == [1]
@test _ss isa SciMLBase.SavedSubsystem
_idxs, _ss = SciMLBase.get_save_idxs_and_saved_subsystem(prob, [x, q])
@test _idxs == [1]
@test _ss isa SciMLBase.SavedSubsystem
_idxs, _ss = SciMLBase.get_save_idxs_and_saved_subsystem(prob, [q])
@test _idxs == Int[]
@test _ss isa SciMLBase.SavedSubsystem
end
end

0 comments on commit 0d263e5

Please sign in to comment.