Skip to content

Commit

Permalink
Merge pull request #853 from AayushSabharwal/as/getsetsym
Browse files Browse the repository at this point in the history
refactor: use `getsym`/`setsym` over `getu`/`setu`
  • Loading branch information
ChrisRackauckas authored Nov 8, 2024
2 parents 51261d1 + 909f93c commit ab44889
Show file tree
Hide file tree
Showing 11 changed files with 71 additions and 71 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ StableRNGs = "1.0"
StaticArrays = "1.7"
StaticArraysCore = "1.4"
Statistics = "1.10"
SymbolicIndexingInterface = "0.3.31"
SymbolicIndexingInterface = "0.3.34"
Tables = "1.11"
Zygote = "0.6.67"
julia = "1.10"
Expand Down
8 changes: 4 additions & 4 deletions src/integrator_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, sym)
if is_parameter(A, sym)
error("Indexing with parameters is deprecated. Use `integrator.ps[$sym]` for parameter indexing.")
end
return getu(A, sym)(A)
return getsym(A, sym)(A)
end

Base.@propagate_inbounds function Base.getindex(
Expand All @@ -501,7 +501,7 @@ Base.@propagate_inbounds function Base.getindex(
is_parameter(A, sym)
error("Indexing with parameters is deprecated. Use `integrator.ps[$sym]` for parameter indexing.")
end
return getu(A, sym)(A)
return getsym(A, sym)(A)
end

Base.@propagate_inbounds function Base.getindex(
Expand All @@ -522,15 +522,15 @@ function Base.setindex!(A::DEIntegrator, val, sym)
if is_parameter(A, sym)
error("Indexing with parameters is deprecated. Use `integrator.ps[$sym] = $val` for parameter indexing.")
end
setu(A, sym)(A, val)
setsym(A, sym)(A, val)
end

function Base.setindex!(A::DEIntegrator, val, sym::Union{AbstractArray, Tuple})
if symbolic_type(sym) == NotSymbolic() && any(x -> is_parameter(A, x), sym) ||
is_parameter(A, sym)
error("Indexing with parameters is deprecated. Use `integrator.ps[$sym] = $val` for parameter indexing.")
end
setu(A, sym)(A, val)
setsym(A, sym)(A, val)
end

### Integrator traits
Expand Down
8 changes: 4 additions & 4 deletions src/problems/problem_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractSciMLProblem, sym)
if is_parameter(A, sym)
error("Indexing with parameters is deprecated. Use `prob.ps[$sym]` for parameter indexing.")
end
return getu(A, sym)(A)
return getsym(A, sym)(A)
end

Base.@propagate_inbounds function Base.getindex(
Expand All @@ -51,7 +51,7 @@ Base.@propagate_inbounds function Base.getindex(
is_parameter(A, sym)
error("Indexing with parameters is deprecated. Use `prob.ps[$sym]` for parameter indexing.")
end
return getu(A, sym)(A)
return getsym(A, sym)(A)
end

function Base.setindex!(prob::AbstractSciMLProblem, args...; kwargs...)
Expand All @@ -62,7 +62,7 @@ function ___internal_setindex!(A::AbstractSciMLProblem, val, sym)
if is_parameter(A, sym)
error("Indexing with parameters is deprecated. Use `prob.ps[$sym] = $val` for parameter indexing.")
end
return setu(A, sym)(A, val)
return setsym(A, sym)(A, val)
end

function ___internal_setindex!(
Expand All @@ -71,5 +71,5 @@ function ___internal_setindex!(
is_parameter(A, sym)
error("Indexing with parameters is deprecated. Use `prob.ps[$sym] = $val` for parameter indexing.")
end
return setu(A, sym)(A, val)
return setsym(A, sym)(A, val)
end
6 changes: 3 additions & 3 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ function fill_u0(prob, u0; defs = nothing, use_defaults = false)
(symbolic_type(defval) != NotSymbolic() || use_defaults)
defval
else
getu(prob, sym)(prob)
getsym(prob, sym)(prob)
end
end
newvals = anydict()
Expand Down Expand Up @@ -671,7 +671,7 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false}, t0)
# used, since any state symbols in the expression were substituted out earlier.
temp_state = ProblemState(; u = state_values(prob), p = p, t = t0)
for (k, v) in u0
u0[k] = symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state)
u0[k] = symbolic_type(v) === NotSymbolic() ? v : getsym(prob, v)(temp_state)
end
return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), p
end
Expand All @@ -692,7 +692,7 @@ function _updated_u0_p_symmap(prob, u0, ::Val{false}, p, ::Val{true}, t0)
# used, since any parameter symbols in the expression were substituted out earlier.
temp_state = ProblemState(; u = u0, p = parameter_values(prob), t = t0)
for (k, v) in p
p[k] = symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state)
p[k] = symbolic_type(v) === NotSymbolic() ? v : getsym(prob, v)(temp_state)
end
return u0, remake_buffer(prob, parameter_values(prob), keys(p), values(p))
end
Expand Down
8 changes: 4 additions & 4 deletions src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs,
end
end
state = ProblemState(; u = sol.interp(t, nothing, deriv, ps, continuity), p = ps, t)
return getu(sol, idxs)(state)
return getsym(sol, idxs)(state)
end

function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVector,
Expand All @@ -321,15 +321,15 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVect
end
end
state = ProblemState(; u = sol.interp(t, nothing, deriv, ps, continuity), p = ps, t)
return getu(sol, idxs)(state)
return getsym(sol, idxs)(state)
end

function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, idxs,
continuity) where {deriv}
symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`")
error_if_observed_derivative(sol, idxs, deriv)
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
getter = getu(sol, idxs)
getter = getsym(sol, idxs)
if is_parameter_timeseries(sol) == NotTimeseries() || !is_discrete_expression(sol, idxs)
interp_sol = augment(sol.interp(t, nothing, deriv, p, continuity), sol)
return DiffEqArray(getter(interp_sol), t, p, sol)
Expand All @@ -353,7 +353,7 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
end
error_if_observed_derivative(sol, idxs, deriv)
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
getter = getu(sol, idxs)
getter = getsym(sol, idxs)
if is_parameter_timeseries(sol) == NotTimeseries() || !is_discrete_expression(sol, idxs)
interp_sol = augment(sol.interp(t, nothing, deriv, p, continuity), sol)
return DiffEqArray(getter(interp_sol), t, p, sol)
Expand Down
6 changes: 3 additions & 3 deletions src/solutions/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, sym)
if is_parameter(A, sym)
error("Indexing with parameters is deprecated. Use `sol.ps[$sym]` for parameter indexing.")
end
return getu(A, sym)(A)
return getsym(A, sym)(A)
end

Base.@propagate_inbounds function Base.getindex(
Expand All @@ -123,7 +123,7 @@ Base.@propagate_inbounds function Base.getindex(
is_parameter(A, sym)
error("Indexing with parameters is deprecated. Use `sol.ps[$sym]` for parameter indexing.")
end
return getu(A, sym)(A)
return getsym(A, sym)(A)
end

Base.@propagate_inbounds function Base.getindex(
Expand Down Expand Up @@ -359,7 +359,7 @@ plottable_indices(x::Number) = 1
xvar = only(independent_variable_symbols(sol))
end
xvals = sol(ts; idxs = xvar).u
# xvals = getu(sol, xvar)(sol, tstart:tend)
# xvals = getsym(sol, xvar)(sol, tstart:tend)
yvals = getp(sol, yvar)(sol, tstart:tend)
tmpvals = map(func, xvals, yvals)
xvals = getindex.(tmpvals, 1)
Expand Down
2 changes: 1 addition & 1 deletion test/downstream/adjoints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ sys2 = complete(sys2)
prob2 = ODEProblem(sys2, [], (0.0, 10.0))

bi = BatchedInterface((sys1, [x, y, z]), (sys2, [x, y, w]))
getter = getu(bi)
getter = getsym(bi)

p1grad, p2grad = Zygote.gradient(prob1, prob2) do prob1, prob2
sum(getter(prob1, prob2))
Expand Down
22 changes: 11 additions & 11 deletions test/downstream/comprehensive_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ timeseries_systems = [osys, ssys, jsys]
((indp.X, indp.Y), Tuple(u[uidxs]), (4.0, 4.0))
((:X, :Y), Tuple(u[uidxs]), (4.0, 4.0))
(Tuple(uidxs), Tuple(u[uidxs]), (4.0, 4.0))]
get = getu(indp, sym)
set! = setu(indp, sym)
get = getsym(indp, sym)
set! = setsym(indp, sym)
@inferred get(valp)
@test get(valp) == val
if valp isa JumpProblem && sym isa Union{Tuple, AbstractArray}
Expand Down Expand Up @@ -153,12 +153,12 @@ timeseries_systems = [osys, ssys, jsys]
([X, indp.Y, :XY, X * Y], [u[uidxs]..., sum(u), prod(u)])
((X, indp.Y, :XY, X * Y), (u[uidxs]..., sum(u), prod(u)))
(X * Y, prod(u))]
get = getu(indp, sym)
get = getsym(indp, sym)
@test get(valp) == val
end
end

getter = getu(indp, [])
getter = getsym(indp, [])
@test getter(valp) == []

p = getindex.((Dict(p_vals),), [kp, kd, k1, k2])
Expand Down Expand Up @@ -264,7 +264,7 @@ end
true)
(X * Y, xvals .* yvals,
false, true)]
get = getu(indp, sym)
get = getsym(indp, sym)
if check_inference
@inferred get(valp)
end
Expand Down Expand Up @@ -416,9 +416,9 @@ end
[[i, [k, j], (k, j)] for (i, j, k) in zip(x_val, y_val, obs_val)], false)
]
if check_inference
@inferred getu(prob, sym)(sol)
@inferred getsym(prob, sym)(sol)
end
@test getu(prob, sym)(sol) == val
@test getsym(prob, sym)(sol) == val
end
end

Expand Down Expand Up @@ -454,8 +454,8 @@ end
((x[1:2], (y_idx, x[3])), (x_probval[1:2], (y_probval, x_probval[3])),
(x_newval[1:2], (y_newval, x_newval[3])), false)
]
getter = getu(prob, sym)
setter! = setu(prob, sym)
getter = getsym(prob, sym)
setter! = setsym(prob, sym)
if check_inference
@inferred getter(prob)
end
Expand Down Expand Up @@ -818,7 +818,7 @@ end
([kp, 2ud1], true, vcat.(kpval, 2 .* ud1val), false),
((kp, 2ud1), true, tuple.(kpval, 2 .* ud1val), false)
]
getter = getu(sys, sym)
getter = getsym(sys, sym)
if check_inference
@inferred getter(sol)
end
Expand Down Expand Up @@ -853,7 +853,7 @@ end
([2x, 3xd1], [2_xval, 3_xd1val], true),
((2x, 3xd2), (2_xval, 3_xd2val), true)
]
getter = getu(sys, sym)
getter = getsym(sys, sym)
@test_throws Exception getter(sol)
for subidx in [1, CartesianIndex(1), :, rand(Bool, 4), rand(1:4, 3), 1:2]
@test_throws Exception getter(sol, subidx)
Expand Down
18 changes: 9 additions & 9 deletions test/downstream/integrator_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -330,21 +330,21 @@ integrator = init(prob, Tsit5(), save_everystep = false)
@test integrator[x] isa Vector{Float64}
@test integrator[@nonamespace sys.x] isa Vector{Float64}

getx = getu(integrator, x)
gety = getu(integrator, :y)
get_arr = getu(integrator, [x, y])
get_tuple = getu(integrator, (x, y))
get_obs = getu(integrator, x[1] / p[1])
getx = getsym(integrator, x)
gety = getsym(integrator, :y)
get_arr = getsym(integrator, [x, y])
get_tuple = getsym(integrator, (x, y))
get_obs = getsym(integrator, x[1] / p[1])
@test getx(integrator) == [1.0, 2.0, 3.0]
@test gety(integrator) == 1.0
@test get_arr(integrator) == [[1.0, 2.0, 3.0], 1.0]
@test get_tuple(integrator) == ([1.0, 2.0, 3.0], 1.0)
@test get_obs(integrator) == 1.0

setx! = setu(integrator, x)
sety! = setu(integrator, :y)
set_arr! = setu(integrator, [x, y])
set_tuple! = setu(integrator, (x, y))
setx! = setsym(integrator, x)
sety! = setsym(integrator, :y)
set_arr! = setsym(integrator, [x, y])
set_tuple! = setsym(integrator, (x, y))

setx!(integrator, [4.0, 5.0, 6.0])
@test getx(integrator) == [4.0, 5.0, 6.0]
Expand Down
2 changes: 1 addition & 1 deletion test/downstream/modelingtoolkit_remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ for (sys, prob) in zip(syss, probs)
@inferred typeof(prob) remake(prob)

baseType = Base.typename(typeof(prob)).wrapper
ugetter = getu(prob, [x, y, z])
ugetter = getsym(prob, [x, y, z])
prob2 = @inferred baseType remake(prob; u0 = [x => 2.0, y => 3.0, z => 4.0])
@test ugetter(prob2) == [2.0, 3.0, 4.0]
prob2 = @inferred baseType remake(prob; u0 = [sys.x => 2.0, sys.y => 3.0, sys.z => 4.0])
Expand Down
60 changes: 30 additions & 30 deletions test/downstream/problem_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,21 +80,21 @@ oprob[sys.y] = 10.0
oprob[:z] = 1.0
@test oprob[z] == oprob[sys.z] == oprob[:z] == 1.0

getx = getu(oprob, x)
gety = getu(oprob, :y)
get_arr = getu(oprob, [x, y])
get_tuple = getu(oprob, (y, z))
get_obs = getu(oprob, sys.x + sys.z + t + σ)
getx = getsym(oprob, x)
gety = getsym(oprob, :y)
get_arr = getsym(oprob, [x, y])
get_tuple = getsym(oprob, (y, z))
get_obs = getsym(oprob, sys.x + sys.z + t + σ)
@test getx(oprob) == 10.0
@test gety(oprob) == 10.0
@test get_arr(oprob) == [10.0, 10.0]
@test get_tuple(oprob) == (10.0, 1.0)
@test get_obs(oprob) == 22.0

setx! = setu(oprob, x)
sety! = setu(oprob, :y)
set_arr! = setu(oprob, [x, y])
set_tuple! = setu(oprob, (y, z))
setx! = setsym(oprob, x)
sety! = setsym(oprob, :y)
set_arr! = setsym(oprob, [x, y])
set_tuple! = setsym(oprob, (y, z))

setx!(oprob, 11.0)
@test getx(oprob) == 11.0
Expand Down Expand Up @@ -168,21 +168,21 @@ sprob[noise_sys.y] = 10.0
sprob[:z] = 1.0
@test sprob[z] == sprob[noise_sys.z] == sprob[:z] == 1.0

getx = getu(sprob, x)
gety = getu(sprob, :y)
get_arr = getu(sprob, [x, y])
get_tuple = getu(sprob, (y, z))
get_obs = getu(sprob, sys.x + sys.z + t + σ)
getx = getsym(sprob, x)
gety = getsym(sprob, :y)
get_arr = getsym(sprob, [x, y])
get_tuple = getsym(sprob, (y, z))
get_obs = getsym(sprob, sys.x + sys.z + t + σ)
@test getx(sprob) == 10.0
@test gety(sprob) == 10.0
@test get_arr(sprob) == [10.0, 10.0]
@test get_tuple(sprob) == (10.0, 1.0)
@test get_obs(sprob) == 22.0

setx! = setu(sprob, x)
sety! = setu(sprob, :y)
set_arr! = setu(sprob, [x, y])
set_tuple! = setu(sprob, (y, z))
setx! = setsym(sprob, x)
sety! = setsym(sprob, :y)
set_arr! = setsym(sprob, [x, y])
set_tuple! = setsym(sprob, (y, z))

setx!(sprob, 11.0)
@test getx(sprob) == 11.0
Expand Down Expand Up @@ -228,9 +228,9 @@ eprob = EnsembleProblem(oprob)
@test eprob.ps[p] == 1.0
@test eprob.ps[:p] == 1.0
@test eprob.ps[osys.p] == 1.0
@test getu(eprob, X)(eprob) == 0.1
@test getu(eprob, :X)(eprob) == 0.1
@test getu(eprob, osys.X)(eprob) == 0.1
@test getsym(eprob, X)(eprob) == 0.1
@test getsym(eprob, :X)(eprob) == 0.1
@test getsym(eprob, osys.X)(eprob) == 0.1
@test getp(eprob, p)(eprob) == 1.0
@test getp(eprob, :p)(eprob) == 1.0
@test getp(eprob, osys.p)(eprob) == 1.0
Expand All @@ -247,11 +247,11 @@ eprob = EnsembleProblem(oprob)
@test eprob.ps[:p] == 0.1
@test_nowarn eprob.ps[osys.p] = 0.0
@test eprob.ps[osys.p] == 0.0
@test_nowarn setu(eprob, X)(eprob, 0.1)
@test_nowarn setsym(eprob, X)(eprob, 0.1)
@test eprob[X] == 0.1
@test_nowarn setu(eprob, :X)(eprob, 0.0)
@test_nowarn setsym(eprob, :X)(eprob, 0.0)
@test eprob[:X] == 0.0
@test_nowarn setu(eprob, osys.X)(eprob, 0.1)
@test_nowarn setsym(eprob, osys.X)(eprob, 0.1)
@test eprob[osys.X] == 0.1
@test_nowarn setp(eprob, p)(eprob, 0.1)
@test eprob.ps[p] == 0.1
Expand Down Expand Up @@ -285,9 +285,9 @@ prob = SteadyStateProblem(osys, u0, ps)
@test prob[X] == prob[osys.X] == prob[:X] == 0.1
@test prob[X2] == prob[osys.X2] == prob[:X2] == 0.2
@test prob[[X, X2]] == prob[[osys.X, osys.X2]] == prob[[:X, :X2]] == [0.1, 0.2]
@test getu(prob, X)(prob) == getu(prob, osys.X)(prob) == getu(prob, :X)(prob) == 0.1
@test getu(prob, X2)(prob) == getu(prob, osys.X2)(prob) == getu(prob, :X2)(prob) == 0.2
@test getu(prob, [X, X2])(prob) == getu(prob, [osys.X, osys.X2])(prob) ==
getu(prob, [:X, :X2])(prob) == [0.1, 0.2]
@test getu(prob, (X, X2))(prob) == getu(prob, (osys.X, osys.X2))(prob) ==
getu(prob, (:X, :X2))(prob) == (0.1, 0.2)
@test getsym(prob, X)(prob) == getsym(prob, osys.X)(prob) == getsym(prob, :X)(prob) == 0.1
@test getsym(prob, X2)(prob) == getsym(prob, osys.X2)(prob) == getsym(prob, :X2)(prob) == 0.2
@test getsym(prob, [X, X2])(prob) == getsym(prob, [osys.X, osys.X2])(prob) ==
getsym(prob, [:X, :X2])(prob) == [0.1, 0.2]
@test getsym(prob, (X, X2))(prob) == getsym(prob, (osys.X, osys.X2))(prob) ==
getsym(prob, (:X, :X2))(prob) == (0.1, 0.2)

0 comments on commit ab44889

Please sign in to comment.