Skip to content

Commit

Permalink
Merge pull request #834 from AayushSabharwal/as/save-idxs-fix
Browse files Browse the repository at this point in the history
feat: add `get_saved_state_idxs`, handle problems with no system in SavedSubsystem constructor
  • Loading branch information
ChrisRackauckas authored Oct 28, 2024
2 parents bf913e7 + b0fd60a commit 699d833
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
27 changes: 27 additions & 0 deletions src/solutions/save_idxs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ function as_diffeq_array(vt::Vector{VectorTemplate}, t)
return DiffEqArray(typeof(TupleOfArraysWrapper(vt))[], t, (1, 1))
end

function is_empty_indp(indp)
isempty(variable_symbols(indp)) && isempty(parameter_symbols(indp)) &&
isempty(independent_variable_symbols(indp))
end

# Everything from this point on is public API

"""
$(TYPEDSIGNATURES)
Expand Down Expand Up @@ -104,6 +111,12 @@ function SavedSubsystem(indp, pobj, saved_idxs)
return nothing
end

# this is required because problems with no system have an empty `SymbolCache`
# as their symbolic container.
if is_empty_indp(indp)
return nothing
end

# array state symbolics must be scalarized
saved_idxs = collect(Iterators.flatten(map(saved_idxs) do sym
if symbolic_type(sym) == NotSymbolic()
Expand Down Expand Up @@ -226,6 +239,20 @@ function SavedSubsystem(indp, pobj, saved_idxs)
timeseries_partition_templates, indexes_in_partition, ts_idx_to_count)
end

"""
$(TYPEDSIGNATURES)
Given a `SavedSubsystem`, return the subset of state indexes of the original system that are
saved, in the order they are saved.
"""
function get_saved_state_idxs(ss::SavedSubsystem)
idxs = Vector{valtype(ss.state_map)}(undef, length(ss.state_map))
for (k, v) in ss.state_map
idxs[v] = k
end
return idxs
end

"""
$(TYPEDEF)
Expand Down
4 changes: 4 additions & 0 deletions test/downstream/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,8 @@ end

ode_sol = solve(prob, Tsit5(); save_idxs = xidx)
subsys = SciMLBase.SavedSubsystem(sys, prob.p, [xidx])
@test SciMLBase.get_saved_state_idxs(subsys) == [xidx]

# FIXME: hack for save_idxs
SciMLBase.@reset ode_sol.saved_subsystem = subsys

Expand Down Expand Up @@ -257,6 +259,7 @@ end
sol = solve(prob; save_idxs = xidx)
xvals = sol[x]
subsys = SciMLBase.SavedSubsystem(sys, prob.p, [x, q, r])
@test SciMLBase.get_saved_state_idxs(subsys) == [xidx]
qvals = sol.ps[q]
rvals = sol.ps[r]
# FIXME: hack for save_idxs
Expand Down Expand Up @@ -290,6 +293,7 @@ end
prob = ODEProblem(sys, [x => 1.0, y => 1.0], (0.0, 5.0),
[p => 0.5, q => 0.0, r => 1.0, s => 10.0, u => 4096.0])
ss = SciMLBase.SavedSubsystem(sys, prob.p, [x, q, s, r])
@test SciMLBase.get_saved_state_idxs(ss) == [xidx]
sswf = SciMLBase.SavedSubsystemWithFallback(ss, sys)
xidx = variable_index(sys, x)
qidx = parameter_index(sys, q)
Expand Down

0 comments on commit 699d833

Please sign in to comment.