From b592989098552f589640c8c8abd4964b0d742912 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 6 Nov 2024 15:36:48 +0530 Subject: [PATCH] feat: add `get_save_idxs_and_saved_subsystem` --- src/solutions/save_idxs.jl | 37 +++++++++++++++++++++++++++ test/downstream/solution_interface.jl | 31 ++++++++++++++++++++++ 2 files changed, 68 insertions(+) diff --git a/src/solutions/save_idxs.jl b/src/solutions/save_idxs.jl index 0b154d538..db21ac72b 100644 --- a/src/solutions/save_idxs.jl +++ b/src/solutions/save_idxs.jl @@ -347,3 +347,40 @@ function SymbolicIndexingInterface.with_updated_parameter_timeseries_values( return ps end + +""" + $(TYPEDSIGNATURES) + +Given a SciMLProblem `prob` and (possibly symbolic) `save_idxs`, return the `save_idxs` +corresponding to the state variables and a `SavedSubsystem` to pass to `build_solution`. + +The second return value (corresponding to the `SavedSubsystem`) may be `nothing` in case +one is not required. `save_idxs` may be a scalar or `nothing`. +""" +function get_save_idxs_and_saved_subsystem(prob, save_idxs) + if save_idxs === nothing + saved_subsystem = nothing + else + if !(save_idxs isa AbstractArray) || symbolic_type(save_idxs) != NotSymbolic() + _save_idxs = [save_idxs] + else + _save_idxs = save_idxs + end + saved_subsystem = SavedSubsystem(prob, parameter_values(prob), _save_idxs) + if saved_subsystem !== nothing + _save_idxs = get_saved_state_idxs(saved_subsystem) + if isempty(_save_idxs) + # no states to save + save_idxs = Int[] + elseif !(save_idxs isa AbstractArray) || symbolic_type(save_idxs) != NotSymbolic() + # only a single state to save, and save it as a scalar timeseries instead of + # single-element array + save_idxs = only(_save_idxs) + else + save_idxs = _save_idxs + end + end + end + + return save_idxs, saved_subsystem +end diff --git a/test/downstream/solution_interface.jl b/test/downstream/solution_interface.jl index cb6a07780..aa62f23f7 100644 --- a/test/downstream/solution_interface.jl +++ b/test/downstream/solution_interface.jl @@ -327,4 +327,35 @@ end parameter_values(ptc), rpidx.timeseries_idx => vals) @test newp[ridx] == 2prob.ps[r] end + + @testset "get_save_idxs_and_saved_subsystem" begin + @variables x(t) y(t) + @parameters p q(t) r(t) s(t) u(t) + evs = [0.1 => [q ~ q + 1, s ~ s - 1], 0.3 => [r ~ 2r, u ~ u / 2]] + @mtkbuild sys = ODESystem([D(x) ~ x + p * y, D(y) ~ 2p + x^2], t, [x, y], + [p, q, r, s, u], discrete_events = evs) + prob = ODEProblem(sys, [x => 1.0, y => 1.0], (0.0, 5.0), + [p => 0.5, q => 0.0, r => 1.0, s => 10.0, u => 4096.0]) + + _idxs, _ss = SciMLBase.get_save_idxs_and_saved_subsystem(prob, nothing) + @test _idxs === _ss === nothing + _idxs, _ss = SciMLBase.get_save_idxs_and_saved_subsystem(prob, 1) + @test _idxs == 1 + @test _ss isa SciMLBase.SavedSubsystem + _idxs, _ss = SciMLBase.get_save_idxs_and_saved_subsystem(prob, [1]) + @test _idxs == [1] + @test _ss isa SciMLBase.SavedSubsystem + _idxs, _ss = SciMLBase.get_save_idxs_and_saved_subsystem(prob, x) + @test _idxs == 1 + @test _ss isa SciMLBase.SavedSubsystem + _idxs, _ss = SciMLBase.get_save_idxs_and_saved_subsystem(prob, [x]) + @test _idxs == [1] + @test _ss isa SciMLBase.SavedSubsystem + _idxs, _ss = SciMLBase.get_save_idxs_and_saved_subsystem(prob, [x, q]) + @test _idxs == [1] + @test _ss isa SciMLBase.SavedSubsystem + _idxs, _ss = SciMLBase.get_save_idxs_and_saved_subsystem(prob, [q]) + @test _idxs == Int[] + @test _ss isa SciMLBase.SavedSubsystem + end end