From 909f93cf52dc9192334e4e7438d834dc26786e97 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 8 Nov 2024 16:01:01 +0530 Subject: [PATCH] refactor: use `getsym`/`setsym` over `getu`/`setu` --- Project.toml | 2 +- src/integrator_interface.jl | 8 +-- src/problems/problem_interface.jl | 8 +-- src/remake.jl | 6 +-- src/solutions/ode_solutions.jl | 8 +-- src/solutions/solution_interface.jl | 6 +-- test/downstream/adjoints.jl | 2 +- test/downstream/comprehensive_indexing.jl | 22 ++++----- test/downstream/integrator_indexing.jl | 18 +++---- test/downstream/modelingtoolkit_remake.jl | 2 +- test/downstream/problem_interface.jl | 60 +++++++++++------------ 11 files changed, 71 insertions(+), 71 deletions(-) diff --git a/Project.toml b/Project.toml index b599088a7..d298c7eb2 100644 --- a/Project.toml +++ b/Project.toml @@ -88,7 +88,7 @@ StableRNGs = "1.0" StaticArrays = "1.7" StaticArraysCore = "1.4" Statistics = "1.10" -SymbolicIndexingInterface = "0.3.31" +SymbolicIndexingInterface = "0.3.34" Tables = "1.11" Zygote = "0.6.67" julia = "1.10" diff --git a/src/integrator_interface.jl b/src/integrator_interface.jl index 45777a27c..b38790410 100644 --- a/src/integrator_interface.jl +++ b/src/integrator_interface.jl @@ -492,7 +492,7 @@ Base.@propagate_inbounds function Base.getindex(A::DEIntegrator, sym) if is_parameter(A, sym) error("Indexing with parameters is deprecated. Use `integrator.ps[$sym]` for parameter indexing.") end - return getu(A, sym)(A) + return getsym(A, sym)(A) end Base.@propagate_inbounds function Base.getindex( @@ -501,7 +501,7 @@ Base.@propagate_inbounds function Base.getindex( is_parameter(A, sym) error("Indexing with parameters is deprecated. Use `integrator.ps[$sym]` for parameter indexing.") end - return getu(A, sym)(A) + return getsym(A, sym)(A) end Base.@propagate_inbounds function Base.getindex( @@ -522,7 +522,7 @@ function Base.setindex!(A::DEIntegrator, val, sym) if is_parameter(A, sym) error("Indexing with parameters is deprecated. Use `integrator.ps[$sym] = $val` for parameter indexing.") end - setu(A, sym)(A, val) + setsym(A, sym)(A, val) end function Base.setindex!(A::DEIntegrator, val, sym::Union{AbstractArray, Tuple}) @@ -530,7 +530,7 @@ function Base.setindex!(A::DEIntegrator, val, sym::Union{AbstractArray, Tuple}) is_parameter(A, sym) error("Indexing with parameters is deprecated. Use `integrator.ps[$sym] = $val` for parameter indexing.") end - setu(A, sym)(A, val) + setsym(A, sym)(A, val) end ### Integrator traits diff --git a/src/problems/problem_interface.jl b/src/problems/problem_interface.jl index ece143ab7..a704d978b 100644 --- a/src/problems/problem_interface.jl +++ b/src/problems/problem_interface.jl @@ -42,7 +42,7 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractSciMLProblem, sym) if is_parameter(A, sym) error("Indexing with parameters is deprecated. Use `prob.ps[$sym]` for parameter indexing.") end - return getu(A, sym)(A) + return getsym(A, sym)(A) end Base.@propagate_inbounds function Base.getindex( @@ -51,7 +51,7 @@ Base.@propagate_inbounds function Base.getindex( is_parameter(A, sym) error("Indexing with parameters is deprecated. Use `prob.ps[$sym]` for parameter indexing.") end - return getu(A, sym)(A) + return getsym(A, sym)(A) end function Base.setindex!(prob::AbstractSciMLProblem, args...; kwargs...) @@ -62,7 +62,7 @@ function ___internal_setindex!(A::AbstractSciMLProblem, val, sym) if is_parameter(A, sym) error("Indexing with parameters is deprecated. Use `prob.ps[$sym] = $val` for parameter indexing.") end - return setu(A, sym)(A, val) + return setsym(A, sym)(A, val) end function ___internal_setindex!( @@ -71,5 +71,5 @@ function ___internal_setindex!( is_parameter(A, sym) error("Indexing with parameters is deprecated. Use `prob.ps[$sym] = $val` for parameter indexing.") end - return setu(A, sym)(A, val) + return setsym(A, sym)(A, val) end diff --git a/src/remake.jl b/src/remake.jl index 0ffd91003..8ed91237f 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -593,7 +593,7 @@ function fill_u0(prob, u0; defs = nothing, use_defaults = false) (symbolic_type(defval) != NotSymbolic() || use_defaults) defval else - getu(prob, sym)(prob) + getsym(prob, sym)(prob) end end newvals = anydict() @@ -671,7 +671,7 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false}, t0) # used, since any state symbols in the expression were substituted out earlier. temp_state = ProblemState(; u = state_values(prob), p = p, t = t0) for (k, v) in u0 - u0[k] = symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state) + u0[k] = symbolic_type(v) === NotSymbolic() ? v : getsym(prob, v)(temp_state) end return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), p end @@ -692,7 +692,7 @@ function _updated_u0_p_symmap(prob, u0, ::Val{false}, p, ::Val{true}, t0) # used, since any parameter symbols in the expression were substituted out earlier. temp_state = ProblemState(; u = u0, p = parameter_values(prob), t = t0) for (k, v) in p - p[k] = symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state) + p[k] = symbolic_type(v) === NotSymbolic() ? v : getsym(prob, v)(temp_state) end return u0, remake_buffer(prob, parameter_values(prob), keys(p), values(p)) end diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index f947d5895..ccc3d7b77 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -296,7 +296,7 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs, end end state = ProblemState(; u = sol.interp(t, nothing, deriv, ps, continuity), p = ps, t) - return getu(sol, idxs)(state) + return getsym(sol, idxs)(state) end function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVector, @@ -321,7 +321,7 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVect end end state = ProblemState(; u = sol.interp(t, nothing, deriv, ps, continuity), p = ps, t) - return getu(sol, idxs)(state) + return getsym(sol, idxs)(state) end function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, idxs, @@ -329,7 +329,7 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`") error_if_observed_derivative(sol, idxs, deriv) p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing - getter = getu(sol, idxs) + getter = getsym(sol, idxs) if is_parameter_timeseries(sol) == NotTimeseries() || !is_discrete_expression(sol, idxs) interp_sol = augment(sol.interp(t, nothing, deriv, p, continuity), sol) return DiffEqArray(getter(interp_sol), t, p, sol) @@ -353,7 +353,7 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, end error_if_observed_derivative(sol, idxs, deriv) p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing - getter = getu(sol, idxs) + getter = getsym(sol, idxs) if is_parameter_timeseries(sol) == NotTimeseries() || !is_discrete_expression(sol, idxs) interp_sol = augment(sol.interp(t, nothing, deriv, p, continuity), sol) return DiffEqArray(getter(interp_sol), t, p, sol) diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index 7cca426f5..f86864caf 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -114,7 +114,7 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, sym) if is_parameter(A, sym) error("Indexing with parameters is deprecated. Use `sol.ps[$sym]` for parameter indexing.") end - return getu(A, sym)(A) + return getsym(A, sym)(A) end Base.@propagate_inbounds function Base.getindex( @@ -123,7 +123,7 @@ Base.@propagate_inbounds function Base.getindex( is_parameter(A, sym) error("Indexing with parameters is deprecated. Use `sol.ps[$sym]` for parameter indexing.") end - return getu(A, sym)(A) + return getsym(A, sym)(A) end Base.@propagate_inbounds function Base.getindex( @@ -359,7 +359,7 @@ plottable_indices(x::Number) = 1 xvar = only(independent_variable_symbols(sol)) end xvals = sol(ts; idxs = xvar).u - # xvals = getu(sol, xvar)(sol, tstart:tend) + # xvals = getsym(sol, xvar)(sol, tstart:tend) yvals = getp(sol, yvar)(sol, tstart:tend) tmpvals = map(func, xvals, yvals) xvals = getindex.(tmpvals, 1) diff --git a/test/downstream/adjoints.jl b/test/downstream/adjoints.jl index caf2a54f3..4e75e19b6 100644 --- a/test/downstream/adjoints.jl +++ b/test/downstream/adjoints.jl @@ -80,7 +80,7 @@ sys2 = complete(sys2) prob2 = ODEProblem(sys2, [], (0.0, 10.0)) bi = BatchedInterface((sys1, [x, y, z]), (sys2, [x, y, w])) -getter = getu(bi) +getter = getsym(bi) p1grad, p2grad = Zygote.gradient(prob1, prob2) do prob1, prob2 sum(getter(prob1, prob2)) diff --git a/test/downstream/comprehensive_indexing.jl b/test/downstream/comprehensive_indexing.jl index 69db72d15..f5c519ce1 100644 --- a/test/downstream/comprehensive_indexing.jl +++ b/test/downstream/comprehensive_indexing.jl @@ -119,8 +119,8 @@ timeseries_systems = [osys, ssys, jsys] ((indp.X, indp.Y), Tuple(u[uidxs]), (4.0, 4.0)) ((:X, :Y), Tuple(u[uidxs]), (4.0, 4.0)) (Tuple(uidxs), Tuple(u[uidxs]), (4.0, 4.0))] - get = getu(indp, sym) - set! = setu(indp, sym) + get = getsym(indp, sym) + set! = setsym(indp, sym) @inferred get(valp) @test get(valp) == val if valp isa JumpProblem && sym isa Union{Tuple, AbstractArray} @@ -153,12 +153,12 @@ timeseries_systems = [osys, ssys, jsys] ([X, indp.Y, :XY, X * Y], [u[uidxs]..., sum(u), prod(u)]) ((X, indp.Y, :XY, X * Y), (u[uidxs]..., sum(u), prod(u))) (X * Y, prod(u))] - get = getu(indp, sym) + get = getsym(indp, sym) @test get(valp) == val end end - getter = getu(indp, []) + getter = getsym(indp, []) @test getter(valp) == [] p = getindex.((Dict(p_vals),), [kp, kd, k1, k2]) @@ -264,7 +264,7 @@ end true) (X * Y, xvals .* yvals, false, true)] - get = getu(indp, sym) + get = getsym(indp, sym) if check_inference @inferred get(valp) end @@ -416,9 +416,9 @@ end [[i, [k, j], (k, j)] for (i, j, k) in zip(x_val, y_val, obs_val)], false) ] if check_inference - @inferred getu(prob, sym)(sol) + @inferred getsym(prob, sym)(sol) end - @test getu(prob, sym)(sol) == val + @test getsym(prob, sym)(sol) == val end end @@ -454,8 +454,8 @@ end ((x[1:2], (y_idx, x[3])), (x_probval[1:2], (y_probval, x_probval[3])), (x_newval[1:2], (y_newval, x_newval[3])), false) ] - getter = getu(prob, sym) - setter! = setu(prob, sym) + getter = getsym(prob, sym) + setter! = setsym(prob, sym) if check_inference @inferred getter(prob) end @@ -818,7 +818,7 @@ end ([kp, 2ud1], true, vcat.(kpval, 2 .* ud1val), false), ((kp, 2ud1), true, tuple.(kpval, 2 .* ud1val), false) ] - getter = getu(sys, sym) + getter = getsym(sys, sym) if check_inference @inferred getter(sol) end @@ -853,7 +853,7 @@ end ([2x, 3xd1], [2_xval, 3_xd1val], true), ((2x, 3xd2), (2_xval, 3_xd2val), true) ] - getter = getu(sys, sym) + getter = getsym(sys, sym) @test_throws Exception getter(sol) for subidx in [1, CartesianIndex(1), :, rand(Bool, 4), rand(1:4, 3), 1:2] @test_throws Exception getter(sol, subidx) diff --git a/test/downstream/integrator_indexing.jl b/test/downstream/integrator_indexing.jl index daa52edf4..0932718ff 100644 --- a/test/downstream/integrator_indexing.jl +++ b/test/downstream/integrator_indexing.jl @@ -330,21 +330,21 @@ integrator = init(prob, Tsit5(), save_everystep = false) @test integrator[x] isa Vector{Float64} @test integrator[@nonamespace sys.x] isa Vector{Float64} -getx = getu(integrator, x) -gety = getu(integrator, :y) -get_arr = getu(integrator, [x, y]) -get_tuple = getu(integrator, (x, y)) -get_obs = getu(integrator, x[1] / p[1]) +getx = getsym(integrator, x) +gety = getsym(integrator, :y) +get_arr = getsym(integrator, [x, y]) +get_tuple = getsym(integrator, (x, y)) +get_obs = getsym(integrator, x[1] / p[1]) @test getx(integrator) == [1.0, 2.0, 3.0] @test gety(integrator) == 1.0 @test get_arr(integrator) == [[1.0, 2.0, 3.0], 1.0] @test get_tuple(integrator) == ([1.0, 2.0, 3.0], 1.0) @test get_obs(integrator) == 1.0 -setx! = setu(integrator, x) -sety! = setu(integrator, :y) -set_arr! = setu(integrator, [x, y]) -set_tuple! = setu(integrator, (x, y)) +setx! = setsym(integrator, x) +sety! = setsym(integrator, :y) +set_arr! = setsym(integrator, [x, y]) +set_tuple! = setsym(integrator, (x, y)) setx!(integrator, [4.0, 5.0, 6.0]) @test getx(integrator) == [4.0, 5.0, 6.0] diff --git a/test/downstream/modelingtoolkit_remake.jl b/test/downstream/modelingtoolkit_remake.jl index d0df44658..a10dd09c7 100644 --- a/test/downstream/modelingtoolkit_remake.jl +++ b/test/downstream/modelingtoolkit_remake.jl @@ -79,7 +79,7 @@ for (sys, prob) in zip(syss, probs) @inferred typeof(prob) remake(prob) baseType = Base.typename(typeof(prob)).wrapper - ugetter = getu(prob, [x, y, z]) + ugetter = getsym(prob, [x, y, z]) prob2 = @inferred baseType remake(prob; u0 = [x => 2.0, y => 3.0, z => 4.0]) @test ugetter(prob2) == [2.0, 3.0, 4.0] prob2 = @inferred baseType remake(prob; u0 = [sys.x => 2.0, sys.y => 3.0, sys.z => 4.0]) diff --git a/test/downstream/problem_interface.jl b/test/downstream/problem_interface.jl index 30588cb0c..b179551dc 100644 --- a/test/downstream/problem_interface.jl +++ b/test/downstream/problem_interface.jl @@ -80,21 +80,21 @@ oprob[sys.y] = 10.0 oprob[:z] = 1.0 @test oprob[z] == oprob[sys.z] == oprob[:z] == 1.0 -getx = getu(oprob, x) -gety = getu(oprob, :y) -get_arr = getu(oprob, [x, y]) -get_tuple = getu(oprob, (y, z)) -get_obs = getu(oprob, sys.x + sys.z + t + σ) +getx = getsym(oprob, x) +gety = getsym(oprob, :y) +get_arr = getsym(oprob, [x, y]) +get_tuple = getsym(oprob, (y, z)) +get_obs = getsym(oprob, sys.x + sys.z + t + σ) @test getx(oprob) == 10.0 @test gety(oprob) == 10.0 @test get_arr(oprob) == [10.0, 10.0] @test get_tuple(oprob) == (10.0, 1.0) @test get_obs(oprob) == 22.0 -setx! = setu(oprob, x) -sety! = setu(oprob, :y) -set_arr! = setu(oprob, [x, y]) -set_tuple! = setu(oprob, (y, z)) +setx! = setsym(oprob, x) +sety! = setsym(oprob, :y) +set_arr! = setsym(oprob, [x, y]) +set_tuple! = setsym(oprob, (y, z)) setx!(oprob, 11.0) @test getx(oprob) == 11.0 @@ -168,21 +168,21 @@ sprob[noise_sys.y] = 10.0 sprob[:z] = 1.0 @test sprob[z] == sprob[noise_sys.z] == sprob[:z] == 1.0 -getx = getu(sprob, x) -gety = getu(sprob, :y) -get_arr = getu(sprob, [x, y]) -get_tuple = getu(sprob, (y, z)) -get_obs = getu(sprob, sys.x + sys.z + t + σ) +getx = getsym(sprob, x) +gety = getsym(sprob, :y) +get_arr = getsym(sprob, [x, y]) +get_tuple = getsym(sprob, (y, z)) +get_obs = getsym(sprob, sys.x + sys.z + t + σ) @test getx(sprob) == 10.0 @test gety(sprob) == 10.0 @test get_arr(sprob) == [10.0, 10.0] @test get_tuple(sprob) == (10.0, 1.0) @test get_obs(sprob) == 22.0 -setx! = setu(sprob, x) -sety! = setu(sprob, :y) -set_arr! = setu(sprob, [x, y]) -set_tuple! = setu(sprob, (y, z)) +setx! = setsym(sprob, x) +sety! = setsym(sprob, :y) +set_arr! = setsym(sprob, [x, y]) +set_tuple! = setsym(sprob, (y, z)) setx!(sprob, 11.0) @test getx(sprob) == 11.0 @@ -228,9 +228,9 @@ eprob = EnsembleProblem(oprob) @test eprob.ps[p] == 1.0 @test eprob.ps[:p] == 1.0 @test eprob.ps[osys.p] == 1.0 -@test getu(eprob, X)(eprob) == 0.1 -@test getu(eprob, :X)(eprob) == 0.1 -@test getu(eprob, osys.X)(eprob) == 0.1 +@test getsym(eprob, X)(eprob) == 0.1 +@test getsym(eprob, :X)(eprob) == 0.1 +@test getsym(eprob, osys.X)(eprob) == 0.1 @test getp(eprob, p)(eprob) == 1.0 @test getp(eprob, :p)(eprob) == 1.0 @test getp(eprob, osys.p)(eprob) == 1.0 @@ -247,11 +247,11 @@ eprob = EnsembleProblem(oprob) @test eprob.ps[:p] == 0.1 @test_nowarn eprob.ps[osys.p] = 0.0 @test eprob.ps[osys.p] == 0.0 -@test_nowarn setu(eprob, X)(eprob, 0.1) +@test_nowarn setsym(eprob, X)(eprob, 0.1) @test eprob[X] == 0.1 -@test_nowarn setu(eprob, :X)(eprob, 0.0) +@test_nowarn setsym(eprob, :X)(eprob, 0.0) @test eprob[:X] == 0.0 -@test_nowarn setu(eprob, osys.X)(eprob, 0.1) +@test_nowarn setsym(eprob, osys.X)(eprob, 0.1) @test eprob[osys.X] == 0.1 @test_nowarn setp(eprob, p)(eprob, 0.1) @test eprob.ps[p] == 0.1 @@ -285,9 +285,9 @@ prob = SteadyStateProblem(osys, u0, ps) @test prob[X] == prob[osys.X] == prob[:X] == 0.1 @test prob[X2] == prob[osys.X2] == prob[:X2] == 0.2 @test prob[[X, X2]] == prob[[osys.X, osys.X2]] == prob[[:X, :X2]] == [0.1, 0.2] -@test getu(prob, X)(prob) == getu(prob, osys.X)(prob) == getu(prob, :X)(prob) == 0.1 -@test getu(prob, X2)(prob) == getu(prob, osys.X2)(prob) == getu(prob, :X2)(prob) == 0.2 -@test getu(prob, [X, X2])(prob) == getu(prob, [osys.X, osys.X2])(prob) == - getu(prob, [:X, :X2])(prob) == [0.1, 0.2] -@test getu(prob, (X, X2))(prob) == getu(prob, (osys.X, osys.X2))(prob) == - getu(prob, (:X, :X2))(prob) == (0.1, 0.2) +@test getsym(prob, X)(prob) == getsym(prob, osys.X)(prob) == getsym(prob, :X)(prob) == 0.1 +@test getsym(prob, X2)(prob) == getsym(prob, osys.X2)(prob) == getsym(prob, :X2)(prob) == 0.2 +@test getsym(prob, [X, X2])(prob) == getsym(prob, [osys.X, osys.X2])(prob) == + getsym(prob, [:X, :X2])(prob) == [0.1, 0.2] +@test getsym(prob, (X, X2))(prob) == getsym(prob, (osys.X, osys.X2))(prob) == + getsym(prob, (:X, :X2))(prob) == (0.1, 0.2)